Skip to content

Commit d0642aa

Browse files
authored
Revert "[BACKEND] Hoist tmem alloc outside of if (#7568)" (#7597)
This reverts commit 96e91d4.
1 parent e6aa86c commit d0642aa

File tree

2 files changed

+1
-136
lines changed

2 files changed

+1
-136
lines changed

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 1 addition & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -148,94 +148,6 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148148
}
149149
};
150150

151-
// Combine back TMEM alloc and store. This is equivalent but gives us a more
152-
// canonical form to do further optimizations.
153-
class CombineTMEMStoreAndAlloc : public OpRewritePattern<TMEMTokenStoreOp> {
154-
public:
155-
using OpRewritePattern::OpRewritePattern;
156-
157-
LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
158-
PatternRewriter &rewriter) const override {
159-
if (!matchPattern(store.getPred(), m_One()))
160-
return failure();
161-
auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
162-
if (!alloc)
163-
return failure();
164-
if (alloc->getBlock() != store->getBlock())
165-
return failure();
166-
alloc.getSrcMutable().assign(store.getSrc());
167-
rewriter.replaceOp(store, alloc.getToken());
168-
return success();
169-
}
170-
};
171-
172-
// Hoists a tmem alloc outside an if op like this:
173-
// %0 = scf.if {
174-
// %1, %token0 = tmem.alloc %init
175-
// ...
176-
// %2 = tmem.load %1, %token1
177-
// scf.yield %2
178-
// } else {
179-
// scf.yield %init
180-
// }
181-
// ->
182-
// %a, %token0 = tmem.alloc %init
183-
// %token2 = scf.if {
184-
//
185-
// ...
186-
// scf.yield %token1
187-
// } else {
188-
// scf.yield %token0
189-
// }
190-
// %2 = tmem.load %a, %token2
191-
class HoistTMEMAllocOutOfIf : public OpRewritePattern<ttng::TMEMAllocOp> {
192-
public:
193-
using OpRewritePattern::OpRewritePattern;
194-
195-
LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
196-
PatternRewriter &rewriter) const override {
197-
if (!alloc.getToken())
198-
return failure();
199-
Value init = alloc.getSrc();
200-
if (!init)
201-
return failure();
202-
auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp());
203-
if (!ifOp)
204-
return failure();
205-
auto thenOp = ifOp.thenBlock()->getTerminator();
206-
auto elseOp = ifOp.elseBlock()->getTerminator();
207-
SmallVector<int> yieldArgs;
208-
for (auto [thenOperand, elseOperand] :
209-
llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) {
210-
auto load = thenOperand.get().getDefiningOp<TMEMTokenLoadOp>();
211-
if (!load || load.getSrc() != alloc.getResult())
212-
continue;
213-
if (elseOperand.get() != init)
214-
continue;
215-
yieldArgs.push_back(thenOperand.getOperandNumber());
216-
}
217-
if (yieldArgs.empty())
218-
return failure();
219-
// Since init is used in the else terminator we know that it dominates the
220-
// if op.
221-
alloc->moveBefore(ifOp);
222-
rewriter.setInsertionPointAfter(ifOp);
223-
for (int argNo : yieldArgs) {
224-
auto load =
225-
cast<TMEMTokenLoadOp>(thenOp->getOperand(argNo).getDefiningOp());
226-
auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone(*load));
227-
rewriter.modifyOpInPlace(ifOp, [&] {
228-
ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult());
229-
newLoad.getDepMutable().assign(ifOp->getResult(argNo));
230-
thenOp->setOperand(argNo, load.getToken());
231-
elseOp->setOperand(argNo, alloc.getToken());
232-
ifOp->getResult(argNo).setType(newLoad.getToken().getType());
233-
});
234-
}
235-
return success();
236-
}
237-
};
238-
239151
// Remove loop-carried tensor dependencies if they are fed immediately into a
240152
// TMEM store by pulling the store into the previous iteration.
241153
class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
@@ -500,29 +412,11 @@ struct HoistTMEMAlloc
500412
mlir::RewritePatternSet patterns(&getContext());
501413
patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
502414
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
503-
SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
504-
HoistTMEMAllocOutOfIf>(&getContext());
415+
SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
505416
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
506417
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
507418
llvm_unreachable("Failed to hoist tmem_store");
508419
}
509-
510-
// TODO: currently some code assumes that a mutable tmem alloc doesn't have
511-
// an initial value. As a workaround we break up the op in order to keep
512-
// this form for the downstream passes. We should remove this once the
513-
// downstream passes are fixed.
514-
m.walk([&](ttng::TMEMAllocOp alloc) {
515-
if (alloc.getType().getMutableMemory() && alloc.getSrc()) {
516-
OpBuilder builder(alloc);
517-
builder.setInsertionPointAfter(alloc);
518-
auto store = builder.create<ttng::TMEMStoreOp>(
519-
alloc.getLoc(), builder.getType<AsyncTokenType>(),
520-
alloc.getResult(), alloc.getToken(), alloc.getSrc(),
521-
builder.create<arith::ConstantIntOp>(alloc.getLoc(), 1, 1));
522-
alloc.getToken().replaceAllUsesExcept(store.getToken(), store);
523-
alloc.getSrcMutable().clear();
524-
}
525-
});
526420
}
527421
};
528422

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -307,32 +307,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
307307
tt.return %res_f16 : tensor<128x128xf16, #blocked>
308308
}
309309
}
310-
311-
// -----
312-
313-
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
314-
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
315-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
316-
// CHECK-LABEL: @hoist_out_of_if
317-
tt.func public @hoist_out_of_if(%arg0: i1, %arg1: tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked> {
318-
// CHECK: %[[A:.+]], %[[T0:.+]] = ttng.tmem_alloc : ()
319-
// CHECK: %[[T1:.+]] = ttng.tmem_store %{{.*}}, %[[A]][%[[T0]]]
320-
// CHECK: %[[I:.+]] = scf.if %{{.+}} -> (!ttg.async.token) {
321-
// CHECK: %[[T2:.+]] = "write_to_tmem"
322-
// CHECK: scf.yield %[[T2]]
323-
// CHECK: } else {
324-
// CHECK: scf.yield %[[T1]]
325-
// CHECK: }
326-
// CHECK: %[[L:.+]], %[[T4:.+]] = ttng.tmem_load %[[A]][%[[I]]
327-
// CHECK: tt.return %[[L]]
328-
%0 = scf.if %arg0 -> (tensor<128x128xf32, #blocked>) {
329-
%result, %token = ttng.tmem_alloc %arg1 : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
330-
%1 = "write_to_tmem"(%result) : (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> !ttg.async.token
331-
%result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
332-
scf.yield %result_0 : tensor<128x128xf32, #blocked>
333-
} else {
334-
scf.yield %arg1 : tensor<128x128xf32, #blocked>
335-
}
336-
tt.return %0 : tensor<128x128xf32, #blocked>
337-
}
338-
}

0 commit comments

Comments
 (0)