Commit ab4a29a

[BACKEND] Recommit tmem alloc hoisting out of if (#7605)
Change the flow to only apply the hoisting out of if after pipelining and warp specialization, as this transformation may block multi-buffering.
1 parent 2f5a031 commit ab4a29a
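
To make the new flow concrete, a compilation pipeline driven through the Python bindings changed below could register the pass twice. This is a sketch only: the build_pipeline wrapper, the pm pass-manager handle, and the exact position of the pipelining passes are assumptions, not part of this commit.

    # Sketch: ordering taken from the commit message; surrounding code assumed.
    from triton._C.libtriton import passes

    def build_pipeline(pm):
        # Keep TMEM allocs inside scf.if while pipelining runs, so that
        # multi-buffering is not blocked.
        passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
        # ... pipelining and warp specialization passes would run here ...
        # Afterwards it is safe to hoist the allocations out of if statements.
        passes.ttgpuir.add_hoist_tmem_alloc(pm, True)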

File tree

7 files changed: +250 -44 lines

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -60,6 +60,11 @@ def TritonGPUHoistTMEMAlloc : Pass<"tritongpu-hoist-tmem-alloc", "mlir::ModuleOp
                          "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                          "mlir::scf::SCFDialect",
                          "mlir::arith::ArithDialect"];
+  let options = [
+    Option<"hoistOutOfIf", "hoist-out-of-if",
+           "bool", /*default*/"false",
+           "Hoist TMEM allocations out of if statements">
+  ];
 }
 
 def TritonGPUTestPipelineLowerLoop : Pass<"tritongpu-test-pipeline-lower-loop", "mlir::ModuleOp"> {

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 176 additions & 27 deletions
@@ -38,12 +38,14 @@ using TMEMTokenLoadOp = HasToken<ttng::TMEMLoadOp>;
 using TMEMTokenStoreOp = HasToken<ttng::TMEMStoreOp>;
 using TMEMTokenAllocOp = HasToken<ttng::TMEMAllocOp>;
 
-class CombineTMEMStoreAndSelect : public OpRewritePattern<TMEMTokenStoreOp> {
+class CombineTMEMStoreAndSelect : public OpRewritePattern<ttng::TMEMStoreOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
+  LogicalResult matchAndRewrite(ttng::TMEMStoreOp store,
                                 PatternRewriter &rewriter) const override {
+    if (!store.getDep())
+      return failure();
     Value src = store.getSrc();
     auto select = src.getDefiningOp<arith::SelectOp>();
     if (!select) {
@@ -79,12 +81,14 @@ class CombineTMEMStoreAndSelect : public OpRewritePattern<TMEMTokenStoreOp> {
   }
 };
 
-class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
+class RemoveUnusedTMEMLoad : public OpRewritePattern<ttng::TMEMLoadOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenLoadOp load,
+  LogicalResult matchAndRewrite(ttng::TMEMLoadOp load,
                                 PatternRewriter &rewriter) const override {
+    if (!load.getDep())
+      return failure();
     if (!load.getResult().use_empty())
       return failure();
     rewriter.replaceAllUsesWith(load.getToken(), load.getDep());
@@ -93,12 +97,14 @@ class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
   }
 };
 
 // Load-store forwarding pattern.
-class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
+class CombineTMEMLoadAndStore : public OpRewritePattern<ttng::TMEMStoreOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
+  LogicalResult matchAndRewrite(ttng::TMEMStoreOp store,
                                 PatternRewriter &rewriter) const override {
+    if (!store.getDep())
+      return failure();
     auto load = store.getDep().getDefiningOp<HasToken<ttng::TMEMLoadOp>>();
     if (!load || load.getResult() != store.getSrc() ||
         load.getSrc() != store.getDst())
@@ -108,12 +114,14 @@ class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
   }
 };
 
-class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
+class SinkTMEMLoad : public OpRewritePattern<ttng::TMEMLoadOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenLoadOp load,
+  LogicalResult matchAndRewrite(ttng::TMEMLoadOp load,
                                 PatternRewriter &rewriter) const override {
+    if (!load.getDep())
+      return failure();
     auto forOp = load->getParentOfType<scf::ForOp>();
     if (!forOp) {
       return failure();
@@ -148,14 +156,130 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
   }
 };
 
+// Combine back TMEM alloc and store. This is equivalent but gives us a more
+// canonical form to do further optimizations.
+class CombineTMEMStoreAndAlloc : public OpRewritePattern<ttng::TMEMStoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMStoreOp store,
+                                PatternRewriter &rewriter) const override {
+    if (!store.getDep())
+      return failure();
+    if (!matchPattern(store.getPred(), m_One()))
+      return failure();
+    auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
+    if (!alloc)
+      return failure();
+    if (store.getDst() != alloc.getResult())
+      return failure();
+    if (alloc->getBlock() != store->getBlock())
+      return failure();
+    alloc.getSrcMutable().assign(store.getSrc());
+    rewriter.replaceOp(store, alloc.getToken());
+    return success();
+  }
+};
+
+// Hoists a tmem alloc outside an if op like this:
+// %0 = scf.if {
+//   %1, %token0 = tmem.alloc %init
+//   ...
+//   %2 = tmem.load %1, %token1
+//   scf.yield %2
+// } else {
+//   scf.yield %init
+// }
+// ->
+// %a, %token0 = tmem.alloc %init
+// %token2 = scf.if {
+//
+//   ...
+//   scf.yield %token1
+// } else {
+//   scf.yield %token0
+// }
+// %2 = tmem.load %a, %token2
+class HoistTMEMAllocOutOfIf : public OpRewritePattern<ttng::TMEMAllocOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
+                                PatternRewriter &rewriter) const override {
+    if (!alloc.getToken())
+      return failure();
+    Value init = alloc.getSrc();
+    if (!init)
+      return failure();
+    auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp());
+    if (!ifOp || !ifOp.elseBlock())
+      return failure();
+    auto thenOp = ifOp.thenBlock()->getTerminator();
+    auto elseOp = ifOp.elseBlock()->getTerminator();
+    SmallVector<int> yieldArgs;
+    for (auto [thenOperand, elseOperand] :
+         llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) {
+      auto load = thenOperand.get().getDefiningOp<TMEMTokenLoadOp>();
+      if (!load || load.getSrc() != alloc.getResult())
+        continue;
+      if (elseOperand.get() != init)
+        continue;
+      yieldArgs.push_back(thenOperand.getOperandNumber());
+    }
+    if (yieldArgs.empty())
+      return failure();
+    // Since init is used in the else terminator we know that it dominates the
+    // if op.
+    alloc->moveBefore(ifOp);
+    rewriter.setInsertionPointAfter(ifOp);
+    for (int argNo : yieldArgs) {
+      auto load =
+          cast<TMEMTokenLoadOp>(thenOp->getOperand(argNo).getDefiningOp());
+      auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone(*load));
+      rewriter.modifyOpInPlace(ifOp, [&] {
+        ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult());
+        newLoad.getDepMutable().assign(ifOp->getResult(argNo));
+        thenOp->setOperand(argNo, load.getToken());
+        elseOp->setOperand(argNo, alloc.getToken());
+        ifOp->getResult(argNo).setType(newLoad.getToken().getType());
+      });
+    }
+    return success();
+  }
+};
+
+// Forward a TMEM load into the user allocation.
+class TMEMLoadForwarding : public OpRewritePattern<ttng::TMEMAllocOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
+                                PatternRewriter &rewriter) const override {
+    if (!alloc.getToken())
+      return failure();
+    Value init = alloc.getSrc();
+    if (!init)
+      return failure();
+    auto load = init.getDefiningOp<TMEMTokenLoadOp>();
+    if (!load || !load->hasOneUse() || !load.getDep().hasOneUse())
+      return failure();
+    if (alloc.getType() != load.getSrc().getType())
+      return failure();
+    rewriter.replaceOp(alloc, {load.getSrc(), load.getDep()});
+    return success();
+  }
+};
+
 // Remove loop-carried tensor dependencies if they are fed immediately into a
 // TMEM store by pulling the store into the previous iteration.
-class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
+class RotateTMEMStoreInLoop : public OpRewritePattern<ttng::TMEMStoreOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
+  LogicalResult matchAndRewrite(ttng::TMEMStoreOp store,
                                 PatternRewriter &rewriter) const override {
+    if (!store.getDep())
+      return failure();
     // Pattern match stores whose source comes from a loop region argument and
     // whose predicate is loop-invariant.
     scf::ForOp forOp = dyn_cast<scf::ForOp>(store->getParentOp());
@@ -215,12 +339,14 @@ class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
 
 // Remove loop-carried tensor dependencies if they are the result of TMEM loads
 // at the end of the loop by pushing the load into the next iteration.
-class RotateTMEMLoadInLoop : public OpRewritePattern<TMEMTokenLoadOp> {
+class RotateTMEMLoadInLoop : public OpRewritePattern<ttng::TMEMLoadOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TMEMTokenLoadOp load,
+  LogicalResult matchAndRewrite(ttng::TMEMLoadOp load,
                                 PatternRewriter &rewriter) const override {
+    if (!load.getDep())
+      return failure();
     // Pattern match loads whose results are only passed into the next iteration
     // of a loop.
     scf::ForOp forOp = dyn_cast<scf::ForOp>(load->getParentOp());
@@ -391,32 +517,55 @@ struct HoistTMEMAlloc
 
   void runOnOperation() override {
     ModuleOp m = getOperation();
-    SmallVector<ttng::MMAv5OpInterface> mmaOps;
-    m.walk([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back(mmaOp); });
-    for (auto mmaOp : mmaOps) {
-      auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp());
-      if (!forOp) {
-        continue;
+    if (!hoistOutOfIf) {
+      SmallVector<ttng::MMAv5OpInterface> mmaOps;
+      m.walk([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back(mmaOp); });
+      for (auto mmaOp : mmaOps) {
+        auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp());
+        if (!forOp) {
+          continue;
+        }
+        hoistInvariantInputs(mmaOp, forOp);
+
+        // Only hoist the TMEM alloc feeding into the accumulator. Leave the
+        // ones for the scales in the loop.
+        auto alloc = mmaOp.getAccumulator().getDefiningOp<TMEMTokenAllocOp>();
+        if (!alloc || alloc->getParentRegion() != mmaOp->getParentRegion()) {
+          continue;
+        }
+        hoistTMEMAlloc(alloc, forOp);
       }
-      hoistInvariantInputs(mmaOp, forOp);
-
-      // Only hoist the TMEM alloc feeding into the accumulator. Leave the ones
-      // for the scales in the loop.
-      auto alloc = mmaOp.getAccumulator().getDefiningOp<TMEMTokenAllocOp>();
-      if (!alloc || alloc->getParentRegion() != mmaOp->getParentRegion()) {
-        continue;
-      }
-      hoistTMEMAlloc(alloc, forOp);
     }
 
     mlir::RewritePatternSet patterns(&getContext());
     patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
                  CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
                  SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
+    if (hoistOutOfIf) {
+      patterns.add<CombineTMEMStoreAndAlloc, HoistTMEMAllocOutOfIf,
+                   TMEMLoadForwarding>(&getContext());
+    }
     scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      llvm_unreachable("Failed to hoist tmem_store");
     }
+
+    // TODO: currently some code assumes that a mutable tmem alloc doesn't have
+    // an initial value. As a workaround we break up the op in order to keep
+    // this form for the downstream passes. We should remove this once the
+    // downstream passes are fixed.
+    m.walk([&](ttng::TMEMAllocOp alloc) {
+      if (alloc.getType().getMutableMemory() && alloc.getSrc()) {
+        OpBuilder builder(alloc);
+        builder.setInsertionPointAfter(alloc);
+        auto store = builder.create<ttng::TMEMStoreOp>(
+            alloc.getLoc(), builder.getType<AsyncTokenType>(),
+            alloc.getResult(), alloc.getToken(), alloc.getSrc(),
+            builder.create<arith::ConstantIntOp>(alloc.getLoc(), 1, 1));
+        alloc.getToken().replaceAllUsesExcept(store.getToken(), store);
+        alloc.getSrcMutable().clear();
+      }
+    });
   }
 };
 

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 7 additions & 4 deletions
@@ -894,11 +894,14 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
         llvm::to_vector(alloc.getResult().getUsers());
     Value replTok = OpBuilder(forOp).create<ub::PoisonOp>(
         forOp.getLoc(), builder.getType<AsyncTokenType>());
+    if (newAlloc.getToken()) {
+      newAlloc.getToken().replaceAllUsesWith(replTok);
+    }
     for (auto user : allocUsers) {
      if (auto store = dyn_cast<ttng::TMEMStoreOp>(user)) {
+        store.getDepMutable().clear();
+        store.getToken().replaceAllUsesWith(replTok);
        if (forOp->isAncestor(store)) {
-          store.getDepMutable().clear();
-          store.getToken().replaceAllUsesWith(replTok);
          // We can multibuffer, since the store is a point where we can
          // change the buffer index
          multibufferingIsValid = true;
@@ -926,9 +929,9 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
          store.getDstMutable().assign(tmemSlice);
        }
      } else if (auto load = dyn_cast<ttng::TMEMLoadOp>(user)) {
+        load.getDepMutable().clear();
+        load.getToken().replaceAllUsesWith(replTok);
        if (forOp->isAncestor(load)) {
-          load.getDepMutable().clear();
-          load.getToken().replaceAllUsesWith(replTok);
          builder.setStageCluster(schedule[load]);
          builder.setInsertionPoint(load);
          Value curBufIdx = getCurrBufIdx(load);

python/src/passes.cc

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce);
   ADD_PASS_WRAPPER_0("add_optimize_thread_locality",
                      createTritonGPUOptimizeThreadLocality);
-  ADD_PASS_WRAPPER_0("add_hoist_tmem_alloc", createTritonGPUHoistTMEMAlloc);
+  ADD_PASS_OPTION_WRAPPER_1("add_hoist_tmem_alloc",
+                            createTritonGPUHoistTMEMAlloc, bool);
   ADD_PASS_OPTION_WRAPPER_1("add_assign_latencies",
                             createTritonGPUAssignLatencies, int);
   ADD_PASS_WRAPPER_0("add_schedule_loops", createTritonGPUScheduleLoops);
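
Because the wrapper changed from ADD_PASS_WRAPPER_0 to ADD_PASS_OPTION_WRAPPER_1, Python call sites now pass the hoist-out-of-if flag explicitly. A minimal sketch of the updated call, with the pass-manager handle pm assumed:

    # Old form took no option: passes.ttgpuir.add_hoist_tmem_alloc(pm)
    # New form takes the option as a bool:
    passes.ttgpuir.add_hoist_tmem_alloc(pm, False)  # keep allocs inside scf.if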

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 46 additions & 0 deletions
@@ -1,4 +1,5 @@
 // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc -canonicalize | FileCheck %s
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-hoist-tmem-alloc="hoist-out-of-if=true" -canonicalize | FileCheck %s -check-prefix=HOIST-IF
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -307,3 +308,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %res_f16 : tensor<128x128xf16, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // HOIST-IF-LABEL: @hoist_out_of_if
+  tt.func public @hoist_out_of_if(%arg0: i1, %arg1: tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked> {
+    // HOIST-IF: %[[A:.+]], %[[T0:.+]] = ttng.tmem_alloc : ()
+    // HOIST-IF: %[[T1:.+]] = ttng.tmem_store %{{.*}}, %[[A]][%[[T0]]]
+    // HOIST-IF: %[[I:.+]] = scf.if %{{.+}} -> (!ttg.async.token) {
+    // HOIST-IF: %[[T2:.+]] = "write_to_tmem"
+    // HOIST-IF: scf.yield %[[T2]]
+    // HOIST-IF: } else {
+    // HOIST-IF: scf.yield %[[T1]]
+    // HOIST-IF: }
+    // HOIST-IF: %[[L:.+]], %[[T4:.+]] = ttng.tmem_load %[[A]][%[[I]]
+    // HOIST-IF: tt.return %[[L]]
+    %0 = scf.if %arg0 -> (tensor<128x128xf32, #blocked>) {
+      %result, %token = ttng.tmem_alloc %arg1 : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+      %1 = "write_to_tmem"(%result) : (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> !ttg.async.token
+      %result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      scf.yield %result_0 : tensor<128x128xf32, #blocked>
+    } else {
+      scf.yield %arg1 : tensor<128x128xf32, #blocked>
+    }
+    tt.return %0 : tensor<128x128xf32, #blocked>
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @forward_tmem_load(%m: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, %t: !ttg.async.token) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) {
+    %result, %token0 = ttng.tmem_load %m[%t] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+    // HOIST-IF-LABEL: @forward_tmem_load
+    // HOIST-IF-SAME: %[[ARG0:.+]]: !ttg.memdesc<128x128xf32,
+    // HOIST-IF-SAME: %[[ARG1:.+]]: !ttg.async.token
+    // HOIST-IF-NEXT: tt.return %[[ARG0]], %[[ARG1]]
+    %result1, %token1 = ttng.tmem_alloc %result : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    tt.return %result1, %token1 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+  }
+}
