
Commit 413b521

[Feature] Support different packing formats in dot_scaled op (#6420)
Add a new flag to pick the packing format for fp4 inputs of the dot_scaled op. This makes it possible to use the hardware fp4 transpose on Blackwell by going through the mixed-mode mxfp4 * mxfp8 path. Note that taking this path prevents reaching full throughput when using dot_scaled with mxfp4 * mxfp4.
1 parent 0209c69 commit 413b521
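
For orientation, the call below sketches how a dot_scaled op is created once the new packing attributes exist. It mirrors the updated TritonOpBuilder binding in python/src/ir.cc further down; the builder object, operand values, scale values, and element formats are placeholders assumed for illustration and are not part of the commit.

    // Sketch only: `builder`, operands, and formats are placeholders.
    Value d = builder.create<triton::DotScaledOp>(
        c.getType(), lhs, rhs, c,
        /*a_scale=*/lhsScale.value_or(Value()),
        /*b_scale=*/rhsScale.value_or(Value()),
        /*a_elem_type=*/lhsFormat, /*b_elem_type=*/rhsFormat,
        /*fastMath=*/false,
        /*lhs_k_pack=*/true,    // fp4 LHS stays packed along K (the default)
        /*rhs_k_pack=*/false);  // fp4 RHS packed along its non-K axis instead

Leaving both flags at their default of true keeps the existing behavior, so current IR and callers are unaffected.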

18 files changed (+303, -115 lines)

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 17 additions & 1 deletion
@@ -53,7 +53,23 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
-        /*args=*/(ins)>
+        /*args=*/(ins)>,
+    InterfaceMethod<
+        /*desc=*/"Verify the dimensions of the DotOp output.",
+        /*retType=*/"bool",
+        /*methodName=*/"verifyOutputDims",
+        /*args=*/(ins),
+        /*methodBody=*/[{}],
+        /*defaultImpl=*/ [{
+          auto aTy = cast<ShapedType>($_op.getA().getType());
+          auto bTy = cast<ShapedType>($_op.getB().getType());
+          auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+          auto aShape = aTy.getShape();
+          auto bShape = bTy.getShape();
+          auto cShape = cTy.getShape();
+          return cShape[cShape.size() - 2] == aShape[aShape.size() - 2] &&
+                 cShape[cShape.size() - 1] == bShape[aShape.size() - 1];
+        }]>
   ];

   let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 2 deletions
@@ -676,7 +676,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 //
 def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
                            AttrSizedOperandSegments,
-                           DeclareOpInterfaceMethods<DotOpInterface>,
+                           DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
                            TypesMatchWith<"result's type matches accumulator's type",
                                           "d", "c", "$_self">]> {
   let summary = "dot_scaled";
@@ -697,7 +697,9 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
                    Optional<RankedTensorOf<[TT_Float, I8]>>:$b_scale,
                    TT_ScaleDotElemTypeAttr:$a_elem_type,
                    TT_ScaleDotElemTypeAttr:$b_elem_type,
-                   BoolAttr:$fastMath
+                   BoolAttr:$fastMath,
+                   DefaultValuedAttr<BoolAttr, "true">:$lhs_k_pack,
+                   DefaultValuedAttr<BoolAttr, "true">:$rhs_k_pack
   );

   let results = (outs TT_FloatTensor:$d);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 6 additions & 2 deletions
@@ -404,7 +404,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
   let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }

-def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
+def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
   let summary = "block level op mapping to tensorcore gen5 mma";

   let description = [{
@@ -423,7 +423,11 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
                        I1:$useD,
                        I1:$pred,
                        Optional<TTG_MemDescType>:$barrier);
-
+  let extraClassDeclaration = [{
+    int64_t getBlockM();
+    int64_t getBlockN();
+    int64_t getBlockK();
+  }];
   // TODO: improve printing format.
   let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }

lib/Dialect/Triton/IR/OpInterfaces.cpp

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ LogicalResult verifyDotOpInterface(Operation *op) {
                             "operand to be equal to the first dimension of "
                             "the result");
   // Check the output shape
-  if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
-      cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
+  if (!dotOp.verifyOutputDims())
     return dotOp->emitOpError(
         "expected the output shape to be the concatenation of the last "
         "dimension of the first operand and the last dimension of the "

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 29 additions & 4 deletions
@@ -323,14 +323,39 @@ bool DotScaledOp::verifyDims() {

   auto aKdim = aShape[aShape.size() - 1];
   auto bKdim = bShape[aShape.size() - 2];
-  if (this->getAElemType() == ScaleDotElemType::E2M1)
-    aKdim *= 2;
-  if (this->getBElemType() == ScaleDotElemType::E2M1)
-    bKdim *= 2;
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (this->getLhsKPack())
+      aKdim *= 2;
+  }
+  if (this->getBElemType() == ScaleDotElemType::E2M1) {
+    if (this->getRhsKPack())
+      bKdim *= 2;
+  }

   return aKdim == bKdim;
 }

+bool DotScaledOp::verifyOutputDims() {
+  auto cShape = this->getC().getType().getShape();
+  auto oMdim = cShape[cShape.size() - 2];
+  auto oNdim = cShape[cShape.size() - 1];
+  auto aShape = this->getA().getType().getShape();
+  auto bShape = this->getB().getType().getShape();
+  auto adim = aShape[aShape.size() - 2];
+  auto bdim = bShape[bShape.size() - 1];
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (!this->getLhsKPack())
+      adim *= 2;
+  }
+  if (this->getBElemType() == ScaleDotElemType::E2M1) {
+    if (!this->getRhsKPack())
+      bdim *= 2;
+  }
+  if (adim != oMdim || bdim != oNdim)
+    return false;
+  return true;
+}
+
 //-- MakeRangeOp --
 OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
   // make_range(start, start + 1) -> constant(start)
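
As a standalone illustration of the rule verifyOutputDims checks (the example shapes are assumptions, not taken from the commit): an E2M1 operand stores two fp4 values per i8 element, so when lhs_k_pack is true the logical K axis of A is twice the stored one, and when it is false the logical M axis is doubled instead.

    #include <array>
    #include <cstdio>

    int main() {
      // Assumed i8 shape of an fp4 (E2M1) LHS operand.
      std::array<long, 2> aShape = {128, 32};
      bool lhsKPack = false; // the m/n-packed variant added by this commit

      long logicalM = aShape[0];
      long logicalK = aShape[1];
      if (lhsKPack)
        logicalK *= 2; // packed along K: logical A is 128 x 64
      else
        logicalM *= 2; // packed along M: logical A is 256 x 32

      std::printf("logical A: %ld x %ld\n", logicalM, logicalK);
      return 0;
    }

With lhs_k_pack = false the output's M dimension must match the doubled aShape[-2], which is exactly the adim adjustment performed above.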

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 23 additions & 22 deletions
@@ -149,6 +149,7 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
 static Value
 getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
                           bool allowTranspose, bool isMMAv5Fp4Padded = false,
+                          bool forceTranspose = false,
                           Operation *op = nullptr /*only for diagnostic*/) {
   OpBuilder::InsertionGuard g(rewriter);
   Value arg = v;
@@ -167,6 +168,8 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
     } else {
       newOrder = {1, 0};
     }
+    if (forceTranspose)
+      std::swap(newOrder[0], newOrder[1]);
   }

   if (newOrder != order && op) {
@@ -648,49 +651,47 @@ class ScaledBlockedToMMAv5

     bool IsAMixedPrecFp4 = false;
     bool IsBMixedPrecFp4 = false;
+    bool isAFP4 = dotOp.getAElemType() == ScaleDotElemType::E2M1;
+    bool isBFP4 = dotOp.getBElemType() == ScaleDotElemType::E2M1;

     if (dotOp.getAElemType() != dotOp.getBElemType()) {
-      if (dotOp.getAElemType() == ScaleDotElemType::E2M1)
+      if (isAFP4)
        IsAMixedPrecFp4 = true;
-      else if (dotOp.getBElemType() == ScaleDotElemType::E2M1)
+      else if (isBFP4)
        IsBMixedPrecFp4 = true;
     }
-
+    // If we use tcgen05.mma.kind.mxf8f6f4 we need to pad the fp4 operands:
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-packing-formats-mxf8f6f4-smem
+    bool isMMAv5Fp4PaddedLhs = IsAMixedPrecFp4 || !dotOp.getLhsKPack();
+    bool isMMAv5Fp4PaddedRhs = IsBMixedPrecFp4 || !dotOp.getRhsKPack();
     // For mixed-precision fp4 operands, set allowTranspose = false, to force
     // the packed axis, K, to be contiguous in SMEM
     a = getSharedMemoryMMAOperand(a, rewriter, 0,
-                                  /*allowTranspose=*/!IsAMixedPrecFp4,
-                                  IsAMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isAFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedLhs,
+                                  /*forceTranspose=*/!dotOp.getLhsKPack(),
+                                  dotOp);
     b = getSharedMemoryMMAOperand(b, rewriter, 1,
-                                  /*allowTranspose=*/!IsBMixedPrecFp4,
-                                  IsBMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isBFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedRhs,
+                                  /*forceTranspose=*/!dotOp.getRhsKPack(),
+                                  dotOp);

     MLIRContext *context = dotOp->getContext();
     unsigned m = 128;
     unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1];
-    unsigned k = 32;
-    // If both operands are E2M1, target the FP4 tensor core implicitly.
-    // This may result in a downstream compile-time error if the scaled TC
-    // descriptor requires options that are unavailable to the .kind=mxf4 mma.
-    // This is likely preferable over a silent runtime performance degradation
-    // from running f4xf4 via .kind=mxf8f6f4
-    if (dotOp.getAElemType() == ScaleDotElemType::E2M1 &&
-        dotOp.getBElemType() == ScaleDotElemType::E2M1) {
-      k = 64;
-    }
-    SmallVector<unsigned> instrShape = {m, n, k};
+
     ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
     Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
-        context, instrShape[0], instrShape[1], /*unpacked=*/true,
-        CTASplitNum[0], CTASplitNum[1]);
+        context, m, n, /*unpacked=*/true, CTASplitNum[0], CTASplitNum[1]);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     Type accMemDescType = triton::gpu::MemDescType::get(
         oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
         tensorMemorySpace,
         /*mutableMemory=*/true);
-    Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout(
-        instrShape[0], instrShape[1], oldRetType, numWarps);
+    Attribute newDistributedEncoding =
+        nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps);
     auto newAccType = RankedTensorType::get(oldRetType.getShape(),
                                             oldRetType.getElementType(),
                                             newDistributedEncoding);

lib/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.cpp

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {

   LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: add support for m/n packed formats.
+    if (!scaledDotOp.getLhsKPack() || !scaledDotOp.getRhsKPack())
+      return failure();
     // Types
     auto computeType = getComputeType(scaledDotOp.getAElemType(),
                                       scaledDotOp.getBElemType(), rewriter);

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 81 additions & 2 deletions
@@ -355,16 +355,55 @@ bool TCGen5MMAScaledOp::verifyDims() {
   auto aShape = this->getA().getType().getShape();
   auto bShape = this->getB().getType().getShape();

+  bool transA = false;
+  if (auto aSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getA().getType().getEncoding())) {
+    transA = aSharedLayout.getTransposed();
+  }
+  bool transB = false;
+  if (auto bSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getB().getType().getEncoding())) {
+    transB = !bSharedLayout.getTransposed();
+  }
   auto aKdim = aShape[aShape.size() - 1];
   auto bKdim = bShape[aShape.size() - 2];
-  if (this->getAType() == ScaleDotElemType::E2M1)
+  if (this->getAType() == ScaleDotElemType::E2M1 && !transA)
     aKdim *= 2;
-  if (this->getBType() == ScaleDotElemType::E2M1)
+  if (this->getBType() == ScaleDotElemType::E2M1 && !transB)
     bKdim *= 2;

   return aKdim == bKdim;
 }

+bool TCGen5MMAScaledOp::verifyOutputDims() {
+  auto aShape = this->getA().getType().getShape();
+  auto bShape = this->getB().getType().getShape();
+  auto cShape = this->getD().getType().getShape();
+  auto oMdim = cShape[cShape.size() - 2];
+  auto oNdim = cShape[cShape.size() - 1];
+
+  int aMdim = aShape[aShape.size() - 2];
+  int bNdim = bShape[bShape.size() - 1];
+  bool transA = false;
+  if (auto aSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getA().getType().getEncoding())) {
+    transA = aSharedLayout.getTransposed();
+  }
+  bool transB = false;
+  if (auto bSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getB().getType().getEncoding())) {
+    transB = !bSharedLayout.getTransposed();
+  }
+  if (this->getAType() == ScaleDotElemType::E2M1 && transA)
+    aMdim *= 2;
+  if (this->getBType() == ScaleDotElemType::E2M1 && transB)
+    bNdim *= 2;
+
+  if (aMdim != oMdim || bNdim != oNdim)
+    return false;
+  return true;
+}
+
 Value TCGen5MMAScaledOp::useAccumulator() { return getUseD(); }

 void TCGen5MMAScaledOp::setUseAccumulator(Value flag) {
@@ -387,6 +426,46 @@ void TCGen5MMAScaledOp::setPredicate(Value pred) {
   getPredMutable().assign(pred);
 }

+int64_t TCGen5MMAScaledOp::getBlockM() {
+  ArrayRef<int64_t> shape = getA().getType().getShape();
+  int64_t blockM = shape[shape.size() - 2];
+  bool transA = false;
+  if (auto aSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getA().getType().getEncoding())) {
+    transA = aSharedLayout.getTransposed();
+  }
+  if (this->getAType() == ScaleDotElemType::E2M1 && transA)
+    blockM *= 2;
+  return blockM;
+}
+
+int64_t TCGen5MMAScaledOp::getBlockN() {
+  ArrayRef<int64_t> shape = getB().getType().getShape();
+  int64_t blockN = shape[shape.size() - 1];
+  bool transB = false;
+  if (auto bSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getB().getType().getEncoding())) {
+    transB = !bSharedLayout.getTransposed();
+  }
+  if (this->getBType() == ScaleDotElemType::E2M1 && transB)
+    blockN *= 2;
+  return blockN;
+}
+
+int64_t TCGen5MMAScaledOp::getBlockK() {
+  ArrayRef<int64_t> shape = getA().getType().getShape();
+  int64_t blockK = shape[shape.size() - 1];
+  bool transA = false;
+  if (auto aSharedLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(
+          getA().getType().getEncoding())) {
+    transA = aSharedLayout.getTransposed();
+  }
+  if (this->getAType() == ScaleDotElemType::E2M1 && !transA)
+    blockK *= 2;
+  return blockK;
+}
+
+// -- TMEMLoadOp --
 // -- TMEMLoadOp --
 LogicalResult TMEMLoadOp::verify() {
   if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
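
The same packing arithmetic drives getBlockM/getBlockN/getBlockK above; here is a standalone sketch with assumed shapes (not taken from the commit). For an E2M1 operand the i8 memdesc packs two elements per byte, so the doubled axis is K in the default layout and M when the NVMMA shared layout is transposed.

    #include <array>
    #include <cstdio>

    int main() {
      // Assumed i8 memdesc shape of an E2M1 (fp4) A operand.
      std::array<long, 2> aShape = {64, 128};
      bool aIsFp4 = true;
      bool transA = true; // shared-memory layout reports transposed

      long blockM = aShape[aShape.size() - 2];
      long blockK = aShape[aShape.size() - 1];
      if (aIsFp4 && transA)
        blockM *= 2; // packed along M: blockM = 128, blockK = 128
      if (aIsFp4 && !transA)
        blockK *= 2; // packed along K: blockM = 64, blockK = 256

      std::printf("blockM = %ld, blockK = %ld\n", blockM, blockK);
      return 0;
    }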

lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp

Lines changed: 2 additions & 9 deletions
@@ -89,15 +89,8 @@ struct TCGen5MMAScaleSharedToTmemConversion
     MLIRContext *context = op->getContext();
     auto aScaleType = op.getAScale().getType();
     auto bScaleType = op.getBScale().getType();
-    int blockM = op.getA()
-                     .getType()
-                     .getShape()[op.getA().getType().getShape().size() - 2];
-    int blockN = op.getB()
-                     .getType()
-                     .getShape()[op.getB().getType().getShape().size() - 1];
-    int blockK = op.getA()
-                     .getType()
-                     .getShape()[op.getA().getType().getShape().size() - 1];
+    int blockM = op.getBlockM();
+    int blockN = op.getBlockN();
     bool anyChanged = false;
     if (isa<SwizzledSharedEncodingAttr>(aScaleType.getEncoding())) {
       anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter, blockM);

python/src/ir.cc

Lines changed: 6 additions & 6 deletions
@@ -1593,12 +1593,12 @@ void init_triton_ir(py::module &&m) {
              std::optional<mlir::Value> &lhs_scale,
              ScaleDotElemType lhs_format, mlir::Value &rhs,
              std::optional<mlir::Value> &rhs_scale,
-             ScaleDotElemType rhs_format, bool fast_math,
-             mlir::Value &c) -> mlir::Value {
-            return self.create<DotScaledOp>(c.getType(), lhs, rhs, c,
-                                            lhs_scale.value_or(Value()),
-                                            rhs_scale.value_or(Value()),
-                                            lhs_format, rhs_format, fast_math);
+             ScaleDotElemType rhs_format, bool fast_math, bool lhs_k_pack,
+             bool rhs_k_pack, mlir::Value &c) -> mlir::Value {
+            return self.create<DotScaledOp>(
+                c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()),
+                rhs_scale.value_or(Value()), lhs_format, rhs_format, fast_math,
+                lhs_k_pack, rhs_k_pack);
           })
       .def("create_floor",
            [](TritonOpBuilder &self, Value &val) -> Value {
