Commit 72ec661

[Blackwell] Add rewrite pattern to merge explicit commit ops into MMAv5 (#8026)
An explicit commit op and the barrier arguments to MMAv5 express the same synchronization pattern. In practice, however, emitting `tcgen05.mma` and `tcgen05.commit` together in the same basic block, as `MMAv5.cpp` does for MMAv5 ops that carry barrier arguments, can sometimes yield slightly more efficient SASS. This PR adds a simple rewrite pattern that folds explicit commit ops into a preceding MMAv5 op, so that use cases relying on explicit commits can also benefit from this low-level optimization. More of the discussion that motivated this work is in triton-lang/triton#7984.
1 parent 37888a9 commit 72ec661
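As an illustration of the transformation (a sketch, not part of the diff below; operand names and memdesc shapes are borrowed from the test added in this commit), the pattern takes an async MMAv5 op followed by an explicit commit and attaches the commit's barrier and predicate to the MMA op as a completion barrier:

// Before: explicit commit op following an async MMAv5 op.
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} :
  !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
  !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
  !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
ttng.tc_gen5_commit %barrier, %barrierPred : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>

// After MergeCommitIntoMMA: the commit op is erased and its barrier (with its
// predicate) becomes a completion-barrier argument of the MMA op, as in the
// test's CHECK line (memdesc types elided here):
//   ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%barrierPred] {is_async} : ...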

File tree

4 files changed: +182 -9 lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,9 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
         "void",
         "setIsAsync",
         (ins "bool":$isAsync)>,
+    InterfaceMethod<"Return true if this MMA op executes asynchronously.",
+                    "bool",
+                    "isAsync">
   ];

   let verify = [{

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 4 additions & 0 deletions
@@ -400,6 +400,8 @@ void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token,
                      useTwoCTAs ? builder.getUnitAttr() : UnitAttr());
 }

+bool TCGen5MMAOp::isAsync() { return getIsAsync(); }
+
 // -- TCGen5MMAScaledOp --
 LogicalResult TCGen5MMAScaledOp::verify() {
   if (!getIsAsync() && !getBarriers().empty()) {
@@ -573,6 +575,8 @@ void TCGen5MMAScaledOp::build(OpBuilder &builder, OperationState &state,
                          barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr());
 }

+bool TCGen5MMAScaledOp::isAsync() { return getIsAsync(); }
+
 // -- TMEMStoreOp --
 static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
                                        MemDescType memdesc, StringRef regName) {

lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp

Lines changed: 106 additions & 9 deletions
@@ -1,3 +1,4 @@
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
@@ -16,15 +17,14 @@ namespace nvidia_gpu {

 namespace {

-template <typename TCGen5MMAOpTy>
-class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {
+class SyncMMALowering : public OpInterfaceRewritePattern<MMAv5OpInterface> {
 public:
-  using OpRewritePattern<TCGen5MMAOpTy>::OpRewritePattern;
+  using OpInterfaceRewritePattern<MMAv5OpInterface>::OpInterfaceRewritePattern;

-  LogicalResult matchAndRewrite(TCGen5MMAOpTy op,
+  LogicalResult matchAndRewrite(MMAv5OpInterface op,
                                 PatternRewriter &rewriter) const override {
     // If the op doesn't have synchronous semantic skip the pattern.
-    if (op.getIsAsync())
+    if (op.isAsync())
       return failure();
     MLIRContext *ctx = op.getContext();
     Location loc = op.getLoc();
@@ -46,7 +46,7 @@ class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {

     rewriter.setInsertionPointAfter(op);
     Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
-    rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase, op.getPred());
+    rewriter.create<WaitBarrierOp>(loc, barrierAlloc, phase, op.getPredicate());
     rewriter.create<InvalBarrierOp>(loc, barrierAlloc);
     return success();
   }
@@ -101,6 +101,103 @@ struct TCGen5MMAScaleSharedToTmemConversion
   }
 };

+std::pair<SmallVector<TCGen5CommitOp>, SmallVector<Value>>
+collectCommitOpsAfter(MMAv5OpInterface mmaOp) {
+  auto isConstTrue = [](Value v) {
+    if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = dyn_cast<BoolAttr>(constOp.getValueAttr())) {
+        return attr.getValue();
+      }
+    }
+    return false;
+  };
+
+  SmallVector<TCGen5CommitOp> commitOps;
+  SmallVector<Value> commitPredicates;
+  auto mmaPred = mmaOp.getPredicate();
+  Operation *nextOp = mmaOp->getNextNode();
+
+  while (nextOp) {
+    if (auto commit = dyn_cast<TCGen5CommitOp>(nextOp)) {
+      // If the mma predicate is true, or mma and commit ops use the same
+      // predicate, it is safe to merge them
+      if (isConstTrue(mmaPred) || mmaPred == commit.getPred()) {
+        commitOps.push_back(commit);
+        commitPredicates.push_back(commit.getPred());
+      }
+    } else if (!isPure(nextOp)) {
+      // Only move commits across pure ops. We also bail here when encountering
+      // another MMAv5 op.
+      break;
+    }
+    nextOp = nextOp->getNextNode();
+  }
+
+  return {commitOps, commitPredicates};
+}
+
+// Return false if defining ops cannot be moved above the target op
+bool moveDefiningOpsBefore(Value val, Operation *target) {
+  SetVector<Operation *> toMove;
+
+  std::function<bool(Value)> collectOpsToMove = [&](Value val) {
+    if (auto defOp = val.getDefiningOp()) {
+      if (defOp->getBlock() == target->getBlock() &&
+          target->isBeforeInBlock(defOp)) {
+        if (!isPure(defOp)) {
+          // This defOp needs to move above the target op, but it is unsafe due
+          // to impurity.
+          return false;
+        }
+        for (Value operand : defOp->getOperands()) {
+          if (!collectOpsToMove(operand)) {
+            return false;
+          }
+        }
+        toMove.insert(defOp);
+      }
+    }
+    return true;
+  };
+
+  if (!collectOpsToMove(val)) {
+    return false;
+  }
+
+  for (Operation *op : toMove) {
+    op->moveBefore(target);
+  }
+
+  return true;
+}
+
+class MergeCommitIntoMMA : public OpInterfaceRewritePattern<MMAv5OpInterface> {
+public:
+  using OpInterfaceRewritePattern<MMAv5OpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(MMAv5OpInterface op,
+                                PatternRewriter &rewriter) const override {
+    auto [commitOps, predicates] = collectCommitOpsAfter(op);
+    if (commitOps.size() == 0) {
+      return llvm::failure();
+    }
+    for (auto [commit, pred] : llvm::zip(commitOps, predicates)) {
+      if (!pred) {
+        pred = rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
+      }
+      if (!moveDefiningOpsBefore(commit.getBarrier(), op) ||
+          !moveDefiningOpsBefore(pred, op)) {
+        // Give up merging a commit if its defining ops cannot be moved above
+        // the mma op.
+        continue;
+      }
+      op.addCompletionBarrier(commit.getBarrier(), pred);
+      rewriter.eraseOp(commit);
+    }
+    return success();
+  }
+};
+
 } // anonymous namespace

 class TritonNvidiaGPUMMALoweringPass
@@ -112,9 +209,9 @@ class TritonNvidiaGPUMMALoweringPass
     ModuleOp m = getOperation();

     mlir::RewritePatternSet patterns(context);
-    patterns
-        .add<SyncMMALowering<TCGen5MMAOp>, SyncMMALowering<TCGen5MMAScaledOp>,
-             TCGen5MMAScaleSharedToTmemConversion>(context);
+    patterns.add<SyncMMALowering, TCGen5MMAScaleSharedToTmemConversion,
+                 MergeCommitIntoMMA>(context);
+
     if (applyPatternsGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
   }

test/TritonNvidiaGPU/mma_lowering.mlir

Lines changed: 69 additions & 0 deletions
@@ -58,3 +58,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: tcgen5_with_commit
+  tt.func @tcgen5_with_commit(
+      // CHECK: [[BARRIER1:%.*]]: !ttg.memdesc<1xi64, #shared
+      %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
+      // CHECK: [[BARRIER_PRED:%.*]]: i1,
+      %barrierPred: i1,
+      // CHECK: [[A_SMEM:%.*]]: !ttg.memdesc<128x128xf8E5M2
+      %a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+      %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+      %c: !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>) {
+    %barrier2 = ttg.local_alloc : () -> !ttg.memdesc<2x1xi64, #shared2, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    // CHECK: [[TRUE:%.*]] = arith.constant true
+    // CHECK: [[BARRIER_SLICE:%.*]] = ttg.memdesc_index
+    // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[BARRIER1]][[[BARRIER_PRED]]], [[BARRIER_SLICE]][[[TRUE]]]
+    %accUse = arith.constant false
+    %pred = arith.constant true
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} :
+      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tc_gen5_commit %barrier, %barrierPred : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+    %barrier_slice = ttg.memdesc_index %barrier2[%c0_i32] : !ttg.memdesc<2x1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<1xi64, #shared2, #smem, mutable, 2x1>
+    ttng.tc_gen5_commit %barrier_slice : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable, 2x1>
+
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} :
+      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+
+    %random_pred = arith.cmpi eq, %barrierPred, %pred : i1
+    scf.if %random_pred {
+      ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} :
+        !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+        !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+        !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    }
+    // This commit should not be merged into any of two mma ops above
+    // CHECK: tc_gen5_commit
+    ttng.tc_gen5_commit %barrier, %barrierPred : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+
+    // The mma predicate is not a constant true. The commit op should not be merged
+    // CHECK: tc_gen5_commit
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %random_pred {is_async} :
+      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tc_gen5_commit %barrier : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+
+    // There is an impure op between mma and commit ops. Do not allow merging in such cases.
+    // CHECK: tc_gen5_commit
+    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} :
+      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
+      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.wait_barrier %barrier, %c0_i32 : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+    ttng.tc_gen5_commit %barrier : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
+
+    tt.return
+  }
+}
