@@ -148,6 +148,94 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148
148
}
149
149
};
150
150
151
+ // Combine back TMEM alloc and store. This is equivalent but gives us a more
152
+ // canonical form to do further optimizations.
153
+ class CombineTMEMStoreAndAlloc : public OpRewritePattern <TMEMTokenStoreOp> {
154
+ public:
155
+ using OpRewritePattern::OpRewritePattern;
156
+
157
+ LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
158
+ PatternRewriter &rewriter) const override {
159
+ if (!matchPattern (store.getPred (), m_One ()))
160
+ return failure ();
161
+ auto alloc = store.getDep ().getDefiningOp <TMEMTokenAllocOp>();
162
+ if (!alloc)
163
+ return failure ();
164
+ if (alloc->getBlock () != store->getBlock ())
165
+ return failure ();
166
+ alloc.getSrcMutable ().assign (store.getSrc ());
167
+ rewriter.replaceOp (store, alloc.getToken ());
168
+ return success ();
169
+ }
170
+ };
171
+
172
+ // Hoists a tmem alloc outside an if op like this:
173
+ // %0 = scf.if {
174
+ // %1, %token0 = tmem.alloc %init
175
+ // ...
176
+ // %2 = tmem.load %1, %token1
177
+ // scf.yield %2
178
+ // } else {
179
+ // scf.yield %init
180
+ // }
181
+ // ->
182
+ // %a, %token0 = tmem.alloc %init
183
+ // %token2 = scf.if {
184
+ //
185
+ // ...
186
+ // scf.yield %token1
187
+ // } else {
188
+ // scf.yield %token0
189
+ // }
190
+ // %2 = tmem.load %a, %token2
191
+ class HoistTMEMAllocOutOfIf : public OpRewritePattern <ttng::TMEMAllocOp> {
192
+ public:
193
+ using OpRewritePattern::OpRewritePattern;
194
+
195
+ LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
196
+ PatternRewriter &rewriter) const override {
197
+ if (!alloc.getToken ())
198
+ return failure ();
199
+ Value init = alloc.getSrc ();
200
+ if (!init)
201
+ return failure ();
202
+ auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp ());
203
+ if (!ifOp)
204
+ return failure ();
205
+ auto thenOp = ifOp.thenBlock ()->getTerminator ();
206
+ auto elseOp = ifOp.elseBlock ()->getTerminator ();
207
+ SmallVector<int > yieldArgs;
208
+ for (auto [thenOperand, elseOperand] :
209
+ llvm::zip (thenOp->getOpOperands (), elseOp->getOpOperands ())) {
210
+ auto load = thenOperand.get ().getDefiningOp <TMEMTokenLoadOp>();
211
+ if (!load || load.getSrc () != alloc.getResult ())
212
+ continue ;
213
+ if (elseOperand.get () != init)
214
+ continue ;
215
+ yieldArgs.push_back (thenOperand.getOperandNumber ());
216
+ }
217
+ if (yieldArgs.empty ())
218
+ return failure ();
219
+ // Since init is used in the else terminator we know that it dominates the
220
+ // if op.
221
+ alloc->moveBefore (ifOp);
222
+ rewriter.setInsertionPointAfter (ifOp);
223
+ for (int argNo : yieldArgs) {
224
+ auto load =
225
+ cast<TMEMTokenLoadOp>(thenOp->getOperand (argNo).getDefiningOp ());
226
+ auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone (*load));
227
+ rewriter.modifyOpInPlace (ifOp, [&] {
228
+ ifOp->getResult (argNo).replaceAllUsesWith (newLoad.getResult ());
229
+ newLoad.getDepMutable ().assign (ifOp->getResult (argNo));
230
+ thenOp->setOperand (argNo, load.getToken ());
231
+ elseOp->setOperand (argNo, alloc.getToken ());
232
+ ifOp->getResult (argNo).setType (newLoad.getToken ().getType ());
233
+ });
234
+ }
235
+ return success ();
236
+ }
237
+ };
238
+
151
239
// Remove loop-carried tensor dependencies if they are fed immediately into a
152
240
// TMEM store by pulling the store into the previous iteration.
153
241
class RotateTMEMStoreInLoop : public OpRewritePattern <TMEMTokenStoreOp> {
@@ -412,11 +500,29 @@ struct HoistTMEMAlloc
412
500
mlir::RewritePatternSet patterns (&getContext ());
413
501
patterns.add <RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
414
502
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
415
- SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext ());
503
+ SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
504
+ HoistTMEMAllocOutOfIf>(&getContext ());
416
505
scf::ForOp::getCanonicalizationPatterns (patterns, &getContext ());
417
506
if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
418
507
llvm_unreachable (" Failed to hoist tmem_store" );
419
508
}
509
+
510
+ // TODO: currently some code assumes that a mutable tmem alloc doesn't have
511
+ // an initial value. As a workaround we break up the op in order to keep
512
+ // this form for the downstream passes. We should remove this once the
513
+ // downstread passes are fixed.
514
+ m.walk ([&](ttng::TMEMAllocOp alloc) {
515
+ if (alloc.getType ().getMutableMemory () && alloc.getSrc ()) {
516
+ OpBuilder builder (alloc);
517
+ builder.setInsertionPointAfter (alloc);
518
+ auto store = builder.create <ttng::TMEMStoreOp>(
519
+ alloc.getLoc (), builder.getType <AsyncTokenType>(),
520
+ alloc.getResult (), alloc.getToken (), alloc.getSrc (),
521
+ builder.create <arith::ConstantIntOp>(alloc.getLoc (), 1 , 1 ));
522
+ alloc.getToken ().replaceAllUsesExcept (store.getToken (), store);
523
+ alloc.getSrcMutable ().clear ();
524
+ }
525
+ });
420
526
}
421
527
};
422
528
0 commit comments