[Blackwell] Add support for mixed precision scaled dot (triton-lang#5799)
Building on triton-lang#5786. The main change is the representation of the RHS in `mxfp8 x mxfp4`, which needs to be in the special layout Blackwell requires, as described in
https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
A new TMA feature can automatically store into such a layout, but this PR does not rely on TMA. Instead, the layout is represented via a linear layout (LL), and `ld.shared` or `cp.async` is used to manually create it in SMEM.
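For intuition, here is a minimal sketch of the padded addressing, under the assumption that each group of 16 fp4 elements (8 bytes of data) along the contiguous axis is followed by the 64 bits of padding mentioned below. `padded_byte_offset` is a hypothetical helper, not Triton's implementation, and the PTX documentation linked above is authoritative; swizzling is ignored here.

```python
# Minimal sketch of the padded fp4 addressing (not Triton's implementation).
# Assumption: every 16 fp4 elements (8 bytes of data) along the contiguous axis
# are followed by 64 bits (8 bytes) of padding; see the PTX docs linked above
# for the authoritative definition. Swizzling is ignored here.

FP4_PER_GROUP = 16   # assumed group size: 16 fp4 elements = 8 data bytes
GROUP_DATA_BYTES = 8
GROUP_PAD_BYTES = 8  # the 64-bit padding

def padded_byte_offset(row: int, col: int, packed_cols: int) -> int:
    """Byte offset of fp4 element (row, col), where col indexes fp4 elements
    along the contiguous axis and packed_cols is the unpadded row length."""
    group_bytes = GROUP_DATA_BYTES + GROUP_PAD_BYTES
    row_bytes = (packed_cols // FP4_PER_GROUP) * group_bytes
    group, lane = divmod(col, FP4_PER_GROUP)
    return row * row_bytes + group * group_bytes + lane // 2  # two fp4 per byte
```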
Integrating this layout into the lowering pipeline turned out to be very simple. After adding the 64-bit padding described above, we need to apply swizzling on top of it. To support the new "padded and swizzled" layout, we only need a few extra steps in `sharedToLinearLayoutLeadingOffset` to take the padding into account. The function can then be viewed as going through the steps `Padded, swizzled offset` -> `(row, unswizzled but padded column)` -> `(row, unswizzled and packed column)`.
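As a rough illustration of that decomposition (all sizes and the XOR swizzle below are placeholder assumptions; the real transform is expressed as a linear layout and depends on `swizzlingByteWidth`):

```python
# Rough sketch of the decomposition described above (assumptions, not Triton's
# code): padded, swizzled offset -> (row, unswizzled but padded column) ->
# (row, unswizzled and packed column).

ROW_BYTES_PADDED = 128   # assumed padded row size in bytes
GROUP_BYTES = 16         # 8 data bytes + 8 padding bytes per fp4 group (assumed)
SWIZZLE_VEC = 16         # swizzle granularity in bytes (assumed)

def to_row_packed_col(padded_swizzled_offset: int) -> tuple[int, int]:
    # Padded, swizzled offset -> (row, swizzled padded column).
    row, swizzled_col = divmod(padded_swizzled_offset, ROW_BYTES_PADDED)

    # Undo the XOR swizzle -> (row, unswizzled but padded column).
    vec, byte = divmod(swizzled_col, SWIZZLE_VEC)
    phase = row % (ROW_BYTES_PADDED // SWIZZLE_VEC)
    padded_col = (vec ^ phase) * SWIZZLE_VEC + byte

    # Drop the padding bytes in each group -> (row, unswizzled and packed column).
    group, in_group = divmod(padded_col, GROUP_BYTES)
    assert in_group < GROUP_BYTES - GROUP_BYTES // 2, "offset falls inside padding"
    return row, group * (GROUP_BYTES // 2) + in_group
```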
Unlike the `mxfp4 x mxfp4` case, Blackwell mixed precision supports a row-major RHS. In that case, the HW expects the N axis to be packed: packing is always done on the contiguous axis. This was experimentally confirmed in my TMA-based branch, but it is obvious in hindsight because TMA is not aware of the K or N axis; it simply supports automatic padding on the packed axis. However, Triton requires that padding is always done on the K axis. This PR supports a row-major RHS functionally, by forcing the RHS SMEM order to be column-major and doing a transpose before the SMEM store if the register layout is row-major. I also needed to disable pipelining of the RHS load in that case, because `cp.async` requires at least 4 bytes of contiguity, which is not satisfied when the on-the-fly transpose is needed.
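A small sketch of that handling, with assumed shapes and NumPy standing in for the register/SMEM layouts:

```python
# Illustrative only: why a row-major RHS is transposed before the SMEM store,
# and why the RHS load is not pipelined in that case. Shapes are assumptions.
import numpy as np

K, N = 64, 32
rhs_regs = np.arange(K * N).reshape(K, N)        # row-major registers: N contiguous

# Padding must be applied along K, so the SMEM buffer is kept column-major
# (K contiguous); transposing in registers first makes the store contiguous in K.
rhs_for_smem = np.ascontiguousarray(rhs_regs.T)  # shape (N, K), K contiguous

# cp.async copies at least 4 contiguous bytes. With the on-the-fly transpose,
# each destination run is far shorter (a handful of fp4 elements), so cp.async
# cannot be used and pipelining of the RHS load is disabled.
```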
@ThomasRaoux @lezcano
---------
Co-authored-by: Masahiro Masuda <[email protected]>
The key attribute change (diff excerpt):

```diff
+  // fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs
+  // to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
   let parameters = (
     ins
     "unsigned":$swizzlingByteWidth,
     "bool":$transposed,
     "unsigned":$elementBitWidth,
+    "bool":$fp4Padded,
     "CTALayoutAttr":$CTALayout
   );

   let builders = [
     AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
                      "CTALayoutAttr":$CTALayout,
-                     "Type":$eltTy), [{
+                     "Type":$eltTy,
+                     "bool": $fp4Padded), [{
       auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
```