@@ -38,12 +38,14 @@ using TMEMTokenLoadOp = HasToken<ttng::TMEMLoadOp>;
3838using TMEMTokenStoreOp = HasToken<ttng::TMEMStoreOp>;
3939using TMEMTokenAllocOp = HasToken<ttng::TMEMAllocOp>;
4040
41- class CombineTMEMStoreAndSelect : public OpRewritePattern <TMEMTokenStoreOp > {
41+ class CombineTMEMStoreAndSelect : public OpRewritePattern <ttng::TMEMStoreOp > {
4242public:
4343 using OpRewritePattern::OpRewritePattern;
4444
45- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
45+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
4646 PatternRewriter &rewriter) const override {
47+ if (!store.getDep ())
48+ return failure ();
4749 Value src = store.getSrc ();
4850 auto select = src.getDefiningOp <arith::SelectOp>();
4951 if (!select) {
@@ -79,12 +81,14 @@ class CombineTMEMStoreAndSelect : public OpRewritePattern<TMEMTokenStoreOp> {
7981 }
8082};
8183
82- class RemoveUnusedTMEMLoad : public OpRewritePattern <TMEMTokenLoadOp > {
84+ class RemoveUnusedTMEMLoad : public OpRewritePattern <ttng::TMEMLoadOp > {
8385public:
8486 using OpRewritePattern::OpRewritePattern;
8587
86- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
88+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
8789 PatternRewriter &rewriter) const override {
90+ if (!load.getDep ())
91+ return failure ();
8892 if (!load.getResult ().use_empty ())
8993 return failure ();
9094 rewriter.replaceAllUsesWith (load.getToken (), load.getDep ());
@@ -93,12 +97,14 @@ class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
9397};
9498
9599// Load-store forwarding pattern.
96- class CombineTMEMLoadAndStore : public OpRewritePattern <TMEMTokenStoreOp > {
100+ class CombineTMEMLoadAndStore : public OpRewritePattern <ttng::TMEMStoreOp > {
97101public:
98102 using OpRewritePattern::OpRewritePattern;
99103
100- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
104+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
101105 PatternRewriter &rewriter) const override {
106+ if (!store.getDep ())
107+ return failure ();
102108 auto load = store.getDep ().getDefiningOp <HasToken<ttng::TMEMLoadOp>>();
103109 if (!load || load.getResult () != store.getSrc () ||
104110 load.getSrc () != store.getDst ())
@@ -108,12 +114,14 @@ class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
108114 }
109115};
110116
111- class SinkTMEMLoad : public OpRewritePattern <TMEMTokenLoadOp > {
117+ class SinkTMEMLoad : public OpRewritePattern <ttng::TMEMLoadOp > {
112118public:
113119 using OpRewritePattern::OpRewritePattern;
114120
115- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
121+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
116122 PatternRewriter &rewriter) const override {
123+ if (!load.getDep ())
124+ return failure ();
117125 auto forOp = load->getParentOfType <scf::ForOp>();
118126 if (!forOp) {
119127 return failure ();
@@ -148,14 +156,130 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148156 }
149157};
150158
159+ // Combine back TMEM alloc and store. This is equivalent but gives us a more
160+ // canonical form to do further optimizations.
161+ class CombineTMEMStoreAndAlloc : public OpRewritePattern <ttng::TMEMStoreOp> {
162+ public:
163+ using OpRewritePattern::OpRewritePattern;
164+
165+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
166+ PatternRewriter &rewriter) const override {
167+ if (!store.getDep ())
168+ return failure ();
169+ if (!matchPattern (store.getPred (), m_One ()))
170+ return failure ();
171+ auto alloc = store.getDep ().getDefiningOp <TMEMTokenAllocOp>();
172+ if (!alloc)
173+ return failure ();
174+ if (store.getSrc () != alloc.getResult ())
175+ return failure ();
176+ if (alloc->getBlock () != store->getBlock ())
177+ return failure ();
178+ alloc.getSrcMutable ().assign (store.getSrc ());
179+ rewriter.replaceOp (store, alloc.getToken ());
180+ return success ();
181+ }
182+ };
183+
184+ // Hoists a tmem alloc outside an if op like this:
185+ // %0 = scf.if {
186+ // %1, %token0 = tmem.alloc %init
187+ // ...
188+ // %2 = tmem.load %1, %token1
189+ // scf.yield %2
190+ // } else {
191+ // scf.yield %init
192+ // }
193+ // ->
194+ // %a, %token0 = tmem.alloc %init
195+ // %token2 = scf.if {
196+ //
197+ // ...
198+ // scf.yield %token1
199+ // } else {
200+ // scf.yield %token0
201+ // }
202+ // %2 = tmem.load %a, %token2
203+ class HoistTMEMAllocOutOfIf : public OpRewritePattern <ttng::TMEMAllocOp> {
204+ public:
205+ using OpRewritePattern::OpRewritePattern;
206+
207+ LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
208+ PatternRewriter &rewriter) const override {
209+ if (!alloc.getToken ())
210+ return failure ();
211+ Value init = alloc.getSrc ();
212+ if (!init)
213+ return failure ();
214+ auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp ());
215+ if (!ifOp || !ifOp.elseBlock ())
216+ return failure ();
217+ auto thenOp = ifOp.thenBlock ()->getTerminator ();
218+ auto elseOp = ifOp.elseBlock ()->getTerminator ();
219+ SmallVector<int > yieldArgs;
220+ for (auto [thenOperand, elseOperand] :
221+ llvm::zip (thenOp->getOpOperands (), elseOp->getOpOperands ())) {
222+ auto load = thenOperand.get ().getDefiningOp <TMEMTokenLoadOp>();
223+ if (!load || load.getSrc () != alloc.getResult ())
224+ continue ;
225+ if (elseOperand.get () != init)
226+ continue ;
227+ yieldArgs.push_back (thenOperand.getOperandNumber ());
228+ }
229+ if (yieldArgs.empty ())
230+ return failure ();
231+ // Since init is used in the else terminator we know that it dominates the
232+ // if op.
233+ alloc->moveBefore (ifOp);
234+ rewriter.setInsertionPointAfter (ifOp);
235+ for (int argNo : yieldArgs) {
236+ auto load =
237+ cast<TMEMTokenLoadOp>(thenOp->getOperand (argNo).getDefiningOp ());
238+ auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone (*load));
239+ rewriter.modifyOpInPlace (ifOp, [&] {
240+ ifOp->getResult (argNo).replaceAllUsesWith (newLoad.getResult ());
241+ newLoad.getDepMutable ().assign (ifOp->getResult (argNo));
242+ thenOp->setOperand (argNo, load.getToken ());
243+ elseOp->setOperand (argNo, alloc.getToken ());
244+ ifOp->getResult (argNo).setType (newLoad.getToken ().getType ());
245+ });
246+ }
247+ return success ();
248+ }
249+ };
250+
251+ // Forward a TMEM load into the user allocation.
252+ class TMEMLoadForwarding : public OpRewritePattern <ttng::TMEMAllocOp> {
253+ public:
254+ using OpRewritePattern::OpRewritePattern;
255+
256+ LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
257+ PatternRewriter &rewriter) const override {
258+ if (!alloc.getToken ())
259+ return failure ();
260+ Value init = alloc.getSrc ();
261+ if (!init)
262+ return failure ();
263+ auto load = init.getDefiningOp <TMEMTokenLoadOp>();
264+ if (!load || !load->hasOneUse () || !load.getDep ().hasOneUse ())
265+ return failure ();
266+ if (alloc.getType () != load.getSrc ().getType ())
267+ return failure ();
268+ rewriter.replaceOp (alloc, {load.getSrc (), load.getDep ()});
269+ return success ();
270+ }
271+ };
272+
151273// Remove loop-carried tensor dependencies if they are fed immediately into a
152274// TMEM store by pulling the store into the previous iteration.
153- class RotateTMEMStoreInLoop : public OpRewritePattern <TMEMTokenStoreOp > {
275+ class RotateTMEMStoreInLoop : public OpRewritePattern <ttng::TMEMStoreOp > {
154276public:
155277 using OpRewritePattern::OpRewritePattern;
156278
157- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
279+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
158280 PatternRewriter &rewriter) const override {
281+ if (!store.getDep ())
282+ return failure ();
159283 // Pattern match stores whose source comes from a loop region argument and
160284 // whose predicate is loop-invariant.
161285 scf::ForOp forOp = dyn_cast<scf::ForOp>(store->getParentOp ());
@@ -215,12 +339,14 @@ class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
215339
216340// Remove loop-carried tensor dependencies if they are the result of TMEM loads
217341// at the end of the loop by pushing the load into the next iteration.
218- class RotateTMEMLoadInLoop : public OpRewritePattern <TMEMTokenLoadOp > {
342+ class RotateTMEMLoadInLoop : public OpRewritePattern <ttng::TMEMLoadOp > {
219343public:
220344 using OpRewritePattern::OpRewritePattern;
221345
222- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
346+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
223347 PatternRewriter &rewriter) const override {
348+ if (!load.getDep ())
349+ return failure ();
224350 // Pattern match loads whose results are only passed into the next iteration
225351 // of a loop.
226352 scf::ForOp forOp = dyn_cast<scf::ForOp>(load->getParentOp ());
@@ -391,32 +517,55 @@ struct HoistTMEMAlloc
391517
392518 void runOnOperation () override {
393519 ModuleOp m = getOperation ();
394- SmallVector<ttng::MMAv5OpInterface> mmaOps;
395- m.walk ([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back (mmaOp); });
396- for (auto mmaOp : mmaOps) {
397- auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp ());
398- if (!forOp) {
399- continue ;
520+ if (!hoistOutOfIf) {
521+ SmallVector<ttng::MMAv5OpInterface> mmaOps;
522+ m.walk ([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back (mmaOp); });
523+ for (auto mmaOp : mmaOps) {
524+ auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp ());
525+ if (!forOp) {
526+ continue ;
527+ }
528+ hoistInvariantInputs (mmaOp, forOp);
529+
530+ // Only hoist the TMEM alloc feeding into the accumulator. Leave the
531+ // ones for the scales in the loop.
532+ auto alloc = mmaOp.getAccumulator ().getDefiningOp <TMEMTokenAllocOp>();
533+ if (!alloc || alloc->getParentRegion () != mmaOp->getParentRegion ()) {
534+ continue ;
535+ }
536+ hoistTMEMAlloc (alloc, forOp);
400537 }
401- hoistInvariantInputs (mmaOp, forOp);
402-
403- // Only hoist the TMEM alloc feeding into the accumulator. Leave the ones
404- // for the scales in the loop.
405- auto alloc = mmaOp.getAccumulator ().getDefiningOp <TMEMTokenAllocOp>();
406- if (!alloc || alloc->getParentRegion () != mmaOp->getParentRegion ()) {
407- continue ;
408- }
409- hoistTMEMAlloc (alloc, forOp);
410538 }
411539
412540 mlir::RewritePatternSet patterns (&getContext ());
413541 patterns.add <RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
414542 CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
415543 SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext ());
544+ if (hoistOutOfIf) {
545+ patterns.add <CombineTMEMStoreAndAlloc, HoistTMEMAllocOutOfIf,
546+ TMEMLoadForwarding>(&getContext ());
547+ }
416548 scf::ForOp::getCanonicalizationPatterns (patterns, &getContext ());
417549 if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
418550 llvm_unreachable (" Failed to hoist tmem_store" );
419551 }
552+
553+ // TODO: currently some code assumes that a mutable tmem alloc doesn't have
554+ // an initial value. As a workaround we break up the op in order to keep
555+ // this form for the downstream passes. We should remove this once the
556+ // downstread passes are fixed.
557+ m.walk ([&](ttng::TMEMAllocOp alloc) {
558+ if (alloc.getType ().getMutableMemory () && alloc.getSrc ()) {
559+ OpBuilder builder (alloc);
560+ builder.setInsertionPointAfter (alloc);
561+ auto store = builder.create <ttng::TMEMStoreOp>(
562+ alloc.getLoc (), builder.getType <AsyncTokenType>(),
563+ alloc.getResult (), alloc.getToken (), alloc.getSrc (),
564+ builder.create <arith::ConstantIntOp>(alloc.getLoc (), 1 , 1 ));
565+ alloc.getToken ().replaceAllUsesExcept (store.getToken (), store);
566+ alloc.getSrcMutable ().clear ();
567+ }
568+ });
420569 }
421570};
422571
0 commit comments