Skip to content

Commit 96e91d4

Browse files
authored
[BACKEND] Hoist tmem alloc outside of if (#7568)
When a tmem alloc is used outside an if, hoisting the alloc out of the if avoids creating a large number of registers at the end of the block.
1 parent 7affb3b commit 96e91d4

File tree

2 files changed

+136
-1
lines changed

2 files changed

+136
-1
lines changed

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,94 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148148
}
149149
};
150150

151+
// Combine back TMEM alloc and store. This is equivalent but gives us a more
152+
// canonical form to do further optimizations.
153+
class CombineTMEMStoreAndAlloc : public OpRewritePattern<TMEMTokenStoreOp> {
154+
public:
155+
using OpRewritePattern::OpRewritePattern;
156+
157+
LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
158+
PatternRewriter &rewriter) const override {
159+
if (!matchPattern(store.getPred(), m_One()))
160+
return failure();
161+
auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
162+
if (!alloc)
163+
return failure();
164+
if (alloc->getBlock() != store->getBlock())
165+
return failure();
166+
alloc.getSrcMutable().assign(store.getSrc());
167+
rewriter.replaceOp(store, alloc.getToken());
168+
return success();
169+
}
170+
};
171+
172+
// Hoists a tmem alloc outside an if op like this:
173+
// %0 = scf.if {
174+
// %1, %token0 = tmem.alloc %init
175+
// ...
176+
// %2 = tmem.load %1, %token1
177+
// scf.yield %2
178+
// } else {
179+
// scf.yield %init
180+
// }
181+
// ->
182+
// %a, %token0 = tmem.alloc %init
183+
// %token2 = scf.if {
184+
//
185+
// ...
186+
// scf.yield %token1
187+
// } else {
188+
// scf.yield %token0
189+
// }
190+
// %2 = tmem.load %a, %token2
191+
class HoistTMEMAllocOutOfIf : public OpRewritePattern<ttng::TMEMAllocOp> {
192+
public:
193+
using OpRewritePattern::OpRewritePattern;
194+
195+
LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
196+
PatternRewriter &rewriter) const override {
197+
if (!alloc.getToken())
198+
return failure();
199+
Value init = alloc.getSrc();
200+
if (!init)
201+
return failure();
202+
auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp());
203+
if (!ifOp)
204+
return failure();
205+
auto thenOp = ifOp.thenBlock()->getTerminator();
206+
auto elseOp = ifOp.elseBlock()->getTerminator();
207+
SmallVector<int> yieldArgs;
208+
for (auto [thenOperand, elseOperand] :
209+
llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) {
210+
auto load = thenOperand.get().getDefiningOp<TMEMTokenLoadOp>();
211+
if (!load || load.getSrc() != alloc.getResult())
212+
continue;
213+
if (elseOperand.get() != init)
214+
continue;
215+
yieldArgs.push_back(thenOperand.getOperandNumber());
216+
}
217+
if (yieldArgs.empty())
218+
return failure();
219+
// Since init is used in the else terminator we know that it dominates the
220+
// if op.
221+
alloc->moveBefore(ifOp);
222+
rewriter.setInsertionPointAfter(ifOp);
223+
for (int argNo : yieldArgs) {
224+
auto load =
225+
cast<TMEMTokenLoadOp>(thenOp->getOperand(argNo).getDefiningOp());
226+
auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone(*load));
227+
rewriter.modifyOpInPlace(ifOp, [&] {
228+
ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult());
229+
newLoad.getDepMutable().assign(ifOp->getResult(argNo));
230+
thenOp->setOperand(argNo, load.getToken());
231+
elseOp->setOperand(argNo, alloc.getToken());
232+
ifOp->getResult(argNo).setType(newLoad.getToken().getType());
233+
});
234+
}
235+
return success();
236+
}
237+
};
238+
151239
// Remove loop-carried tensor dependencies if they are fed immediately into a
152240
// TMEM store by pulling the store into the previous iteration.
153241
class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
@@ -412,11 +500,29 @@ struct HoistTMEMAlloc
412500
mlir::RewritePatternSet patterns(&getContext());
413501
patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
414502
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
415-
SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
503+
SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
504+
HoistTMEMAllocOutOfIf>(&getContext());
416505
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
417506
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
418507
llvm_unreachable("Failed to hoist tmem_store");
419508
}
509+
510+
// TODO: currently some code assumes that a mutable tmem alloc doesn't have
511+
// an initial value. As a workaround we break up the op in order to keep
512+
// this form for the downstream passes. We should remove this once the
513+
// downstream passes are fixed.
514+
m.walk([&](ttng::TMEMAllocOp alloc) {
515+
if (alloc.getType().getMutableMemory() && alloc.getSrc()) {
516+
OpBuilder builder(alloc);
517+
builder.setInsertionPointAfter(alloc);
518+
auto store = builder.create<ttng::TMEMStoreOp>(
519+
alloc.getLoc(), builder.getType<AsyncTokenType>(),
520+
alloc.getResult(), alloc.getToken(), alloc.getSrc(),
521+
builder.create<arith::ConstantIntOp>(alloc.getLoc(), 1, 1));
522+
alloc.getToken().replaceAllUsesExcept(store.getToken(), store);
523+
alloc.getSrcMutable().clear();
524+
}
525+
});
420526
}
421527
};
422528

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
307307
tt.return %res_f16 : tensor<128x128xf16, #blocked>
308308
}
309309
}
310+
311+
// -----
312+
313+
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
314+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
315+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
316+
// Verifies HoistTMEMAllocOutOfIf: the tmem_alloc is hoisted above the
// scf.if (its init value becoming a separate tmem_store), the if then
// yields only an async token, and the tensor result is produced by a
// tmem_load placed after the if.
// CHECK-LABEL: @hoist_out_of_if
317+
tt.func public @hoist_out_of_if(%arg0: i1, %arg1: tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked> {
318+
// CHECK: %[[A:.+]], %[[T0:.+]] = ttng.tmem_alloc : ()
319+
// CHECK: %[[T1:.+]] = ttng.tmem_store %{{.*}}, %[[A]][%[[T0]]]
320+
// CHECK: %[[I:.+]] = scf.if %{{.+}} -> (!ttg.async.token) {
321+
// CHECK: %[[T2:.+]] = "write_to_tmem"
322+
// CHECK: scf.yield %[[T2]]
323+
// CHECK: } else {
324+
// CHECK: scf.yield %[[T1]]
325+
// CHECK: }
326+
// CHECK: %[[L:.+]], %[[T4:.+]] = ttng.tmem_load %[[A]][%[[I]]
327+
// CHECK: tt.return %[[L]]
328+
%0 = scf.if %arg0 -> (tensor<128x128xf32, #blocked>) {
329+
%result, %token = ttng.tmem_alloc %arg1 : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
330+
%1 = "write_to_tmem"(%result) : (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> !ttg.async.token
331+
%result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
332+
scf.yield %result_0 : tensor<128x128xf32, #blocked>
333+
} else {
334+
scf.yield %arg1 : tensor<128x128xf32, #blocked>
335+
}
336+
tt.return %0 : tensor<128x128xf32, #blocked>
337+
}
338+
}

0 commit comments

Comments
 (0)