
Commit 5d2a1d2

masahi authored and makslevental committed
[Blackwell] Add support for mixed precision scaled dot (triton-lang#5799)
Building on triton-lang#5786. The main change is the representation of the RHS in `mxfp8 x mxfp4`, which needs to be in the special layout for Blackwell described in
https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory

A new feature in TMA can automatically store into such a layout, but this PR does not rely on TMA. Instead, the layout is represented via LL, and `ld.shared` or `cp.async` is used to manually create it in SMEM.

Integration of this layout into the lowering pipeline turned out to be very simple. After adding the 64-bit padding described above, we need to apply swizzling on top of it. To support the new "padded and swizzled" layout, we just need to add a few steps that take the padding into account in `sharedToLinearLayoutLeadingOffset`. This function can then be seen as going through the steps `padded, swizzled offset` -> `(row, unswizzled but padded column)` -> `(row, unswizzled and packed column)`.

Unlike the `mxfp4 x mxfp4` case, Blackwell mixed precision supports a row-major RHS. In this case, the HW expects the N axis to be packed - packing is always done on the contiguous axis. This was experimentally confirmed in my TMA-based branch, but it is obvious in hindsight because TMA is not aware of the K or N axis; it simply supports automatic padding on the packed axis. However, Triton requires that padding is always done on the K axis. This PR supports a row-major RHS functionally, by forcing the RHS SMEM order to be column-major and doing a transpose before the SMEM store if the register layout is row-major. I also needed to disable pipelining of the RHS load in that case, because `cp.async` requires at least 4 bytes of contiguity, which is not satisfied when the on-the-fly transpose is needed.

@ThomasRaoux @lezcano

---------

Co-authored-by: Masahiro Masuda <[email protected]>
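To make the `padded, swizzled offset` -> `(row, unswizzled but padded column)` -> `(row, unswizzled and packed column)` chain concrete, here is a small self-contained walkthrough in plain C++. The tile geometry and the XOR-based unswizzling formula are illustrative assumptions, not code from this PR; only the padding arithmetic (8 payload bytes followed by 8 padding bytes per 16-byte group, i.e. 64 bits of padding) follows the layout described above.

```cpp
// Walkthrough of the three layout steps for one SMEM byte offset:
//   padded, swizzled offset
//     -> (row, unswizzled but padded column)
//     -> (row, unswizzled and packed column)
// rowBytes, vec, perPhase, maxPhase are illustrative assumptions.
#include <cassert>
#include <cstdio>

int main() {
  const int rowBytes = 128; // padded row size in bytes
  const int vec = 16, perPhase = 1, maxPhase = 8;

  int offset = 3 * rowBytes + 52; // some padded, swizzled SMEM byte offset

  // Step 1: split into (row, swizzled padded column).
  int row = offset / rowBytes;
  int swizzledCol = offset % rowBytes;

  // Step 2: undo the XOR swizzle -> unswizzled but still padded column.
  int phase = (row / perPhase) % maxPhase;
  int paddedCol = swizzledCol ^ (vec * phase);

  // Step 3: drop the padding -> packed column. Each 16-byte group holds
  // 8 payload bytes followed by 8 bytes of padding (64 bits).
  int packedCol = paddedCol / 16 * 8 + paddedCol % 8;

  std::printf("offset %d -> row %d, padded col %d, packed col %d\n", offset,
              row, paddedCol, packedCol);
  assert(row == 3 && paddedCol == 4 && packedCol == 4);
  return 0;
}
```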
1 parent e1bb2cc commit 5d2a1d2

File tree: 20 files changed, +432 −73 lines changed


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 2 deletions
@@ -427,8 +427,8 @@ class SharedMemoryObject {
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
                                 RewriterBase &rewriter) const {
     auto allocShape = memDesc.getAllocShape();
-    auto allocShapePerCTA =
-        triton::gpu::getShapePerCTA(memDesc.getEncoding(), allocShape);
+    auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
+        memDesc.getEncoding(), allocShape);
     auto layoutOrder = triton::gpu::getOrder(memDesc.getEncoding());
     auto allocStrides = SharedMemoryObject::getStridesForShape(
         allocShapePerCTA, layoutOrder, loc, rewriter);
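For context, a minimal standalone sketch of how strides follow from an allocation shape and a layout order (a simplified stand-in for `getStridesForShape`, which is not shown in this diff): using the allocation shape rather than the logical one makes the strides step over the fp4 padding.

```cpp
// Minimal sketch (not the Triton implementation): strides derived from an
// allocation shape and a layout order, where order[0] is the fastest dim.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> stridesForShape(const std::vector<int64_t> &shape,
                                     const std::vector<unsigned> &order) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned dim : order) { // fastest-varying dimension first
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}

int main() {
  // Logical 128x64 fp4 RHS tile; the packed (contiguous) axis is doubled to
  // 128 in the allocation shape (hypothetical numbers for illustration).
  std::vector<int64_t> allocShape = {128, 128};
  std::vector<unsigned> order = {1, 0}; // column index is contiguous
  auto strides = stridesForShape(allocShape, order);
  assert(strides[1] == 1 && strides[0] == 128);
  return 0;
}
```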

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 0 deletions
@@ -165,11 +165,19 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
 */
 SmallVector<unsigned> getShapePerCTATile(Attribute layout);

+// Returns the "logical" shape per CTA
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Type type);

+// Returns the shape per CTA, which is "physically" allocated
+// Such shapes may be bigger than the logical one due to, for example, padding
+// in shared memory.
+SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
+                                              ArrayRef<int64_t> shape);
+SmallVector<int64_t> getAllocationShapePerCTA(Type type);
+
 unsigned getNumWarpsPerCTA(Attribute layout);

 unsigned getNumCTAs(Attribute layout);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 9 additions & 3 deletions
@@ -419,26 +419,32 @@ def NVMMASharedEncodingAttr :
     https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
   }];

+
+  // fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs
+  // to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
   let parameters = (
     ins
     "unsigned":$swizzlingByteWidth,
     "bool":$transposed,
     "unsigned":$elementBitWidth,
+    "bool":$fp4Padded,
     "CTALayoutAttr":$CTALayout
   );

   let builders = [
     AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
                      "CTALayoutAttr":$CTALayout,
-                     "Type":$eltTy), [{
+                     "Type":$eltTy,
+                     "bool": $fp4Padded), [{
       auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
       int32_t swizzlingByteWidth = 0;
       unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth();
+      int packingFactor = fp4Padded ? 2 : 1;

       // get proper shared memory swizzling mode from the contiguous dimension
       // size of the origin blocked layout.
-      auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
+      auto contigDimSizeInByte = shapePerCTA[order[0]] * packingFactor * eleBitWidth / 8;
       if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
         swizzlingByteWidth = 128;
       } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
@@ -449,7 +455,7 @@ def NVMMASharedEncodingAttr :
         llvm_unreachable("unsupported shared memory layout for MMAv3");
       }
       bool transposed = order[0] == 0;
-      return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, CTALayout);
+      return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout);
     }]>
   ];

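A rough standalone illustration of the builder logic above (plain C++, not the TableGen builder): the swizzling mode is picked from the byte size of the contiguous dimension, and `fp4Padded` doubles that size because each 16-byte group carries only 8 bytes of payload. The 32-byte fallback branch is an assumption, since the hunk above is truncated.

```cpp
// Sketch of the swizzling-width selection, with the fp4Padded packing factor
// applied to the contiguous-dimension byte size (illustrative only).
#include <cassert>
#include <cstdint>

unsigned pickSwizzlingByteWidth(int64_t contigDimSize, unsigned elemBitWidth,
                                bool fp4Padded) {
  int packingFactor = fp4Padded ? 2 : 1;
  int64_t contigDimSizeInByte = contigDimSize * packingFactor * elemBitWidth / 8;
  if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0)
    return 128;
  if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0)
    return 64;
  if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0)
    return 32; // assumed fallback branch
  return 0;    // unsupported in this sketch
}

int main() {
  // 64 packed-fp4 bytes per row: without padding this selects 64B swizzling;
  // with the padded layout the physical row is 128 bytes and selects 128B.
  assert(pickSwizzlingByteWidth(64, 8, /*fp4Padded=*/false) == 64);
  assert(pickSwizzlingByteWidth(64, 8, /*fp4Padded=*/true) == 128);
  return 0;
}
```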

include/triton/Tools/LinearLayout.h

Lines changed: 2 additions & 3 deletions
@@ -683,9 +683,8 @@ class LinearLayout {
   // Otherwise, R could map some tensor index that is not stored in S.
   //
   // One requirement we *don't* have is that S is injective; we allow two shmem
-  // offsets to hold the same 2D index. If S is not injective, there's
-  // ambiguity in which offset we choose for a given (lane, warp). For now we
-  // don't place any guarantees on the choices made by this function.
+  // offsets to hold the same 2D index. If S is not injective,
+  // the algorithm chooses the smallest offset for a given (lane, warp).
   [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const;

   // Get the layout that is the inverse of this layout.
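A toy model of the guarantee this comment now documents, independent of the `LinearLayout` class: when several offsets hold the same coordinate, inverting the map by scanning offsets in increasing order keeps the smallest one, which is what lets the fp4-padded layout ignore the padding offsets.

```cpp
// Toy model of inverting a non-injective offset -> column map by keeping the
// smallest offset per column (the behavior documented above).
#include <cassert>
#include <map>
#include <vector>

int main() {
  // Offsets 0 and 2 both hold column 0; offsets 1 and 3 both hold column 1.
  std::vector<int> offsetToCol = {0, 1, 0, 1};

  std::map<int, int> colToOffset;
  for (int off = 0; off < (int)offsetToCol.size(); ++off)
    colToOffset.try_emplace(offsetToCol[off], off); // first (smallest) wins

  assert(colToOffset[0] == 0);
  assert(colToOffset[1] == 1);
  return 0;
}
```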

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ class AllocationAnalysis {
     // Bytes could be a different value once we support padding or other
     // allocation policies.
     auto allocType = alloc.getType();
-    auto shapePerCTA = gpu::getShapePerCTA(allocType);
+    auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
     auto bytes = product<int64_t>(shapePerCTA) *
                  allocType.getElementTypeBitWidth() / 8;


lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 34 additions & 3 deletions
@@ -388,11 +388,34 @@ SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape) {
   return getShapePerCTA(splitNum, shape);
 }

+SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
+                                              ArrayRef<int64_t> shapeLogical) {
+  SmallVector<int64_t> shape(shapeLogical);
+  if (auto sharedMMALayout = mlir::dyn_cast<NVMMASharedEncodingAttr>(layout)) {
+    if (sharedMMALayout.getFp4Padded()) {
+      auto packedAxis = getOrder(sharedMMALayout)[0];
+      if (shape.size() == 3) {
+        // Take into account multi buffering
+        shape[1 + packedAxis] *= 2;
+      } else {
+        shape[packedAxis] *= 2;
+      }
+    }
+  }
+  return getShapePerCTA(layout, shape);
+}
+
 SmallVector<int64_t> getShapePerCTA(Type type) {
   auto tensorType = cast<TensorOrMemDesc>(type);
   return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape());
 }

+SmallVector<int64_t> getAllocationShapePerCTA(Type type) {
+  auto tensorType = cast<TensorOrMemDesc>(type);
+  return getAllocationShapePerCTA(tensorType.getEncoding(),
+                                  tensorType.getShape());
+}
+
 unsigned getNumWarpsPerCTA(Attribute layout) {
   SmallVector<unsigned> warpsPerCTA;
   if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout))
@@ -1913,7 +1936,8 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};

   unsigned swizzlingByteWidth;
-  bool transposed;
+  bool transposed = false;
+  bool fp4Padded = false;
   unsigned elementBitWidth;
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
   std::optional<SmallVector<unsigned>> CTASplitNum;
@@ -1929,6 +1953,9 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
     } else if (attr.getName() == "elementBitWidth") {
       if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed())
         return {};
+    } else if (attr.getName() == "fp4Padded") {
+      if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed())
+        return {};
     } else if (attr.getName() == "CTAsPerCGA") {
       if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
               .failed())
@@ -1955,14 +1982,18 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {

   return parser.getChecked<NVMMASharedEncodingAttr>(
       parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth,
-      *CTALayout);
+      fp4Padded, *CTALayout);
 }

 void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
           << "swizzlingByteWidth = " << getSwizzlingByteWidth() //
           << ", transposed = " << getTransposed() //
           << ", elementBitWidth = " << getElementBitWidth();
+  if (getFp4Padded()) {
+    // Print only in this case to reduce the noise for the more common case.
+    printer << ", fp4Padded = true";
+  }
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/2);
   printer << "}>";
@@ -2602,7 +2633,7 @@ struct TritonGPUInferLayoutInterface
     }
     resultEncoding = NVMMASharedEncodingAttr::get(
         ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
-        enc.getElementBitWidth(), *ctaLayout);
+        enc.getElementBitWidth(), enc.getFp4Padded(), *ctaLayout);
     return success();
   }

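A self-contained sketch of the shape adjustment implemented by `getAllocationShapePerCTA` above, with the encoding reduced to a flag and a packed-axis index for illustration (CTA splitting is omitted):

```cpp
// Sketch of the allocation-shape computation for the fp4-padded layout:
// double the packed (contiguous) axis, skipping the multi-buffering dim.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> allocationShape(std::vector<int64_t> shape,
                                     bool fp4Padded, unsigned packedAxis) {
  if (fp4Padded) {
    if (shape.size() == 3)
      shape[1 + packedAxis] *= 2; // dim 0 is the multi-buffering dimension
    else
      shape[packedAxis] *= 2;
  }
  return shape;
}

int main() {
  // 2D case: logical [128, 64] i8 tile with the column axis packed.
  auto s2 = allocationShape({128, 64}, /*fp4Padded=*/true, /*packedAxis=*/1);
  assert(s2[0] == 128 && s2[1] == 128);

  // 3D case: 3-stage pipelined (multi-buffered) memdesc [3, 128, 64].
  auto s3 = allocationShape({3, 128, 64}, true, 1);
  assert(s3[0] == 3 && s3[1] == 128 && s3[2] == 128);

  // The physically allocated size (what AllocationAnalysis now accounts for)
  // doubles: 128 * 128 * 8 / 8 = 16384 bytes instead of the logical 8192.
  return 0;
}
```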

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 28 additions & 3 deletions
@@ -202,8 +202,16 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,

   int tileRows = 8;
   int tileCols = 8 * tileWidthBytes / elemBitWidth;
+  bool isFp4Padded = false;
+  if (auto sharedMMALayout =
+          dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(shared)) {
+    if (sharedMMALayout.getFp4Padded()) {
+      isFp4Padded = true;
+    }
+  }
+  int packingFactor = isFp4Padded ? 2 : 1;

-  if (shape[colDim] < tileCols || shape[rowDim] < tileRows) {
+  if (shape[colDim] * packingFactor < tileCols || shape[rowDim] < tileRows) {
     llvm::errs() << "Illegal shared layout; expected shape to be at least ["
                  << tileRows << ", " << tileCols << "], shape: ["
                  << shape[rowDim] << ", " << shape[colDim] << "]\n";
@@ -215,15 +223,32 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,

   std::vector<std::vector<int>> bases2D;
   for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) {
-    bases2D.push_back({0, 1 << logCol});
+    if (isFp4Padded) {
+      int colPadded = 1 << logCol;
+      // Each group of 16 offsets consists of 8 "real" and 8 "padded" offsets.
+      // We represent the padded layout by mapping 8 padded offsets to the same
+      // coordinates as the real ones. When computing the inverse of this LL,
+      // the offsets corresponding to the real ones are picked in the image by
+      // invertAndCompose.
+      int colPacked = colPadded / 16 * 8 + colPadded % 8;
+      bases2D.push_back({0, colPacked});
+    } else {
+      bases2D.push_back({0, 1 << logCol});
+    }
   }
   for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) {
     int row = 1 << logRow;
     if (disableSwizzle) {
       bases2D.push_back({row, 0});
       continue;
     }
-    bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)});
+    if (isFp4Padded) {
+      int colPadded = vec * ((row / perPhase) % maxPhase);
+      int colPacked = colPadded / 16 * 8 + colPadded % 8;
+      bases2D.push_back({row, colPacked});
+    } else {
+      bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)});
+    }
   }
   LinearLayout tileLayout =
       LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName});
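A minimal numeric model of the remapped column bases built above: `colPacked = colPadded / 16 * 8 + colPadded % 8` sends offsets in the padding half of each 16-byte group back onto the packed columns of the payload half (the values below are illustrative):

```cpp
// Model of the padded -> packed column remapping used for the fp4-padded
// offset bases: offsets in the padding half of each 16-byte group contribute
// no new columns, and higher offset bits address the next packed columns.
#include <cassert>
#include <cstdio>

int packedCol(int colPadded) { return colPadded / 16 * 8 + colPadded % 8; }

int main() {
  // Power-of-two offset bases within a tile row (one packed fp4 pair per
  // byte, so elemBitWidth is 8; numbers are illustrative).
  assert(packedCol(1) == 1);
  assert(packedCol(2) == 2);
  assert(packedCol(4) == 4);
  assert(packedCol(8) == 0);  // offset 8 starts the padding half of the group
  assert(packedCol(16) == 8); // next 16-byte group -> next 8 packed columns
  assert(packedCol(32) == 16);

  // A full 16-offset group: offsets 0..7 are "real", 8..15 alias them.
  for (int off = 0; off < 16; ++off)
    std::printf("offset %2d -> packed column %d\n", off, packedCol(off));
  return 0;
}
```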

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 33 additions & 10 deletions
@@ -141,8 +141,10 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,

 // Returns a shared memory allocation that can be used by a dotMMA op for the
 // given value.
-static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
-                                       int opIdx, bool allowTranspose) {
+static Value
+getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
+                          bool allowTranspose, bool isMMAv5Fp4Padded = false,
+                          Operation *op = nullptr /*only for diagnostic*/) {
   OpBuilder::InsertionGuard g(rewriter);
   Value arg = v;
   if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>())
@@ -161,12 +163,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
     }
   }

+  if (newOrder != getOrder(argType.getEncoding()) && op) {
+    op->emitWarning("Warning: Forcing a different order [")
+        << newOrder[0] << ", " << newOrder[1]
+        << "] on SMEM than the register order for the operand " << opIdx
+        << ". Registers will be transposed before SMEM store and the pipelined "
+           "load for this operand will be disabled, so poor performance is "
+           "expected.";
+  }
+
   Attribute SharedMemorySpace =
       SharedMemorySpaceAttr::get(argType.getContext());
   auto CTALayout = getCTALayout(argType.getEncoding());
   auto newLayout = NVMMASharedEncodingAttr::get(
       argType.getContext(), argType.getShape(), newOrder, CTALayout,
-      argType.getElementType());
+      argType.getElementType(), isMMAv5Fp4Padded);
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
@@ -582,11 +593,6 @@ class ScaledBlockedToMMAv5
         mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
       return failure();

-    if (dotOp.getLhsType() != dotOp.getRhsType()) {
-      // Mixed precision is not supported yet.
-      return failure();
-    }
-
     if (dotOp.getLhsScale() == nullptr || dotOp.getRhsScale() == nullptr) {
       return failure();
     }
@@ -607,8 +613,25 @@ class ScaledBlockedToMMAv5
     auto oldAType = dotOp.getLhs().getType();
     auto oldBType = dotOp.getRhs().getType();

-    a = getSharedMemoryMMAOperand(a, rewriter, 0, /*allowTranspose=*/true);
-    b = getSharedMemoryMMAOperand(b, rewriter, 1, /*allowTranspose=*/true);
+    bool IsAMixedPrecFp4 = false;
+    bool IsBMixedPrecFp4 = false;
+
+    if (dotOp.getLhsType() != dotOp.getRhsType()) {
+      if (dotOp.getLhsType() == ScaleDotElemType::E2M1)
+        IsAMixedPrecFp4 = true;
+      else if (dotOp.getRhsType() == ScaleDotElemType::E2M1)
+        IsBMixedPrecFp4 = true;
+    }
+
+    // For mixed-precision fp4 operands, set allowTranspose = false, to force
+    // the packed axis, K, to be contiguous in SMEM
+    a = getSharedMemoryMMAOperand(a, rewriter, 0,
+                                  /*allowTranspose=*/!IsAMixedPrecFp4,
+                                  IsAMixedPrecFp4, dotOp);
+    b = getSharedMemoryMMAOperand(b, rewriter, 1,
+                                  /*allowTranspose=*/!IsBMixedPrecFp4,
+                                  IsBMixedPrecFp4, dotOp);
+
     MLIRContext *context = dotOp->getContext();
     unsigned m = 128;
     unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1];
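A condensed standalone sketch of the operand classification above, using a local enum as a stand-in for `ScaleDotElemType`: only in the mixed-precision case is the fp4 (E2M1) operand flagged, and that operand then disallows transposition so its packed K axis stays contiguous in SMEM.

```cpp
// Sketch of the mixed-precision fp4 operand classification (stand-in enum,
// not the Triton ScaleDotElemType).
#include <cassert>

enum class ElemType { E2M1 /*fp4*/, E4M3 /*fp8*/, E5M2 /*fp8*/ };

struct OperandFlags {
  bool aMixedPrecFp4 = false;
  bool bMixedPrecFp4 = false;
};

OperandFlags classify(ElemType lhs, ElemType rhs) {
  OperandFlags f;
  if (lhs != rhs) { // only the mixed-precision case gets the padded layout
    if (lhs == ElemType::E2M1)
      f.aMixedPrecFp4 = true;
    else if (rhs == ElemType::E2M1)
      f.bMixedPrecFp4 = true;
  }
  return f;
}

int main() {
  // mxfp8 x mxfp4: the RHS is the fp4 operand, so it gets the padded layout
  // and allowTranspose = false (K must stay contiguous in SMEM).
  auto f = classify(ElemType::E4M3, ElemType::E2M1);
  assert(!f.aMixedPrecFp4 && f.bMixedPrecFp4);

  // mxfp4 x mxfp4: not mixed precision, no padded layout is forced here.
  f = classify(ElemType::E2M1, ElemType::E2M1);
  assert(!f.aMixedPrecFp4 && !f.bMixedPrecFp4);
  return 0;
}
```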

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 2 additions & 1 deletion
@@ -125,7 +125,8 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
     // all CTALayouts are the same.
     auto newInnerEnc = NVMMASharedEncodingAttr::get(
         getContext(), srcTy.getShape(), newInnerCvtOrder,
-        allocEncoding.getCTALayout(), srcTy.getElementType());
+        allocEncoding.getCTALayout(), srcTy.getElementType(),
+        allocEncoding.getFp4Padded());

     MemDescType innerTy =
         MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc,
