@@ -38,12 +38,14 @@ using TMEMTokenLoadOp = HasToken<ttng::TMEMLoadOp>;
38
38
using TMEMTokenStoreOp = HasToken<ttng::TMEMStoreOp>;
39
39
using TMEMTokenAllocOp = HasToken<ttng::TMEMAllocOp>;
40
40
41
- class CombineTMEMStoreAndSelect : public OpRewritePattern <TMEMTokenStoreOp > {
41
+ class CombineTMEMStoreAndSelect : public OpRewritePattern <ttng::TMEMStoreOp > {
42
42
public:
43
43
using OpRewritePattern::OpRewritePattern;
44
44
45
- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
45
+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
46
46
PatternRewriter &rewriter) const override {
47
+ if (!store.getDep ())
48
+ return failure ();
47
49
Value src = store.getSrc ();
48
50
auto select = src.getDefiningOp <arith::SelectOp>();
49
51
if (!select) {
@@ -79,12 +81,14 @@ class CombineTMEMStoreAndSelect : public OpRewritePattern<TMEMTokenStoreOp> {
79
81
}
80
82
};
81
83
82
- class RemoveUnusedTMEMLoad : public OpRewritePattern <TMEMTokenLoadOp > {
84
+ class RemoveUnusedTMEMLoad : public OpRewritePattern <ttng::TMEMLoadOp > {
83
85
public:
84
86
using OpRewritePattern::OpRewritePattern;
85
87
86
- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
88
+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
87
89
PatternRewriter &rewriter) const override {
90
+ if (!load.getDep ())
91
+ return failure ();
88
92
if (!load.getResult ().use_empty ())
89
93
return failure ();
90
94
rewriter.replaceAllUsesWith (load.getToken (), load.getDep ());
@@ -93,12 +97,14 @@ class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
93
97
};
94
98
95
99
// Load-store forwarding pattern.
96
- class CombineTMEMLoadAndStore : public OpRewritePattern <TMEMTokenStoreOp > {
100
+ class CombineTMEMLoadAndStore : public OpRewritePattern <ttng::TMEMStoreOp > {
97
101
public:
98
102
using OpRewritePattern::OpRewritePattern;
99
103
100
- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
104
+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
101
105
PatternRewriter &rewriter) const override {
106
+ if (!store.getDep ())
107
+ return failure ();
102
108
auto load = store.getDep ().getDefiningOp <HasToken<ttng::TMEMLoadOp>>();
103
109
if (!load || load.getResult () != store.getSrc () ||
104
110
load.getSrc () != store.getDst ())
@@ -108,12 +114,14 @@ class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
108
114
}
109
115
};
110
116
111
- class SinkTMEMLoad : public OpRewritePattern <TMEMTokenLoadOp > {
117
+ class SinkTMEMLoad : public OpRewritePattern <ttng::TMEMLoadOp > {
112
118
public:
113
119
using OpRewritePattern::OpRewritePattern;
114
120
115
- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
121
+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
116
122
PatternRewriter &rewriter) const override {
123
+ if (!load.getDep ())
124
+ return failure ();
117
125
auto forOp = load->getParentOfType <scf::ForOp>();
118
126
if (!forOp) {
119
127
return failure ();
@@ -148,14 +156,130 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148
156
}
149
157
};
150
158
159
+ // Combine back TMEM alloc and store. This is equivalent but gives us a more
160
+ // canonical form to do further optimizations.
161
+ class CombineTMEMStoreAndAlloc : public OpRewritePattern <ttng::TMEMStoreOp> {
162
+ public:
163
+ using OpRewritePattern::OpRewritePattern;
164
+
165
+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
166
+ PatternRewriter &rewriter) const override {
167
+ if (!store.getDep ())
168
+ return failure ();
169
+ if (!matchPattern (store.getPred (), m_One ()))
170
+ return failure ();
171
+ auto alloc = store.getDep ().getDefiningOp <TMEMTokenAllocOp>();
172
+ if (!alloc)
173
+ return failure ();
174
+ if (store.getSrc () != alloc.getResult ())
175
+ return failure ();
176
+ if (alloc->getBlock () != store->getBlock ())
177
+ return failure ();
178
+ alloc.getSrcMutable ().assign (store.getSrc ());
179
+ rewriter.replaceOp (store, alloc.getToken ());
180
+ return success ();
181
+ }
182
+ };
183
+
184
+ // Hoists a tmem alloc outside an if op like this:
185
+ // %0 = scf.if {
186
+ // %1, %token0 = tmem.alloc %init
187
+ // ...
188
+ // %2 = tmem.load %1, %token1
189
+ // scf.yield %2
190
+ // } else {
191
+ // scf.yield %init
192
+ // }
193
+ // ->
194
+ // %a, %token0 = tmem.alloc %init
195
+ // %token2 = scf.if {
196
+ //
197
+ // ...
198
+ // scf.yield %token1
199
+ // } else {
200
+ // scf.yield %token0
201
+ // }
202
+ // %2 = tmem.load %a, %token2
203
+ class HoistTMEMAllocOutOfIf : public OpRewritePattern <ttng::TMEMAllocOp> {
204
+ public:
205
+ using OpRewritePattern::OpRewritePattern;
206
+
207
+ LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
208
+ PatternRewriter &rewriter) const override {
209
+ if (!alloc.getToken ())
210
+ return failure ();
211
+ Value init = alloc.getSrc ();
212
+ if (!init)
213
+ return failure ();
214
+ auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp ());
215
+ if (!ifOp || !ifOp.elseBlock ())
216
+ return failure ();
217
+ auto thenOp = ifOp.thenBlock ()->getTerminator ();
218
+ auto elseOp = ifOp.elseBlock ()->getTerminator ();
219
+ SmallVector<int > yieldArgs;
220
+ for (auto [thenOperand, elseOperand] :
221
+ llvm::zip (thenOp->getOpOperands (), elseOp->getOpOperands ())) {
222
+ auto load = thenOperand.get ().getDefiningOp <TMEMTokenLoadOp>();
223
+ if (!load || load.getSrc () != alloc.getResult ())
224
+ continue ;
225
+ if (elseOperand.get () != init)
226
+ continue ;
227
+ yieldArgs.push_back (thenOperand.getOperandNumber ());
228
+ }
229
+ if (yieldArgs.empty ())
230
+ return failure ();
231
+ // Since init is used in the else terminator we know that it dominates the
232
+ // if op.
233
+ alloc->moveBefore (ifOp);
234
+ rewriter.setInsertionPointAfter (ifOp);
235
+ for (int argNo : yieldArgs) {
236
+ auto load =
237
+ cast<TMEMTokenLoadOp>(thenOp->getOperand (argNo).getDefiningOp ());
238
+ auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone (*load));
239
+ rewriter.modifyOpInPlace (ifOp, [&] {
240
+ ifOp->getResult (argNo).replaceAllUsesWith (newLoad.getResult ());
241
+ newLoad.getDepMutable ().assign (ifOp->getResult (argNo));
242
+ thenOp->setOperand (argNo, load.getToken ());
243
+ elseOp->setOperand (argNo, alloc.getToken ());
244
+ ifOp->getResult (argNo).setType (newLoad.getToken ().getType ());
245
+ });
246
+ }
247
+ return success ();
248
+ }
249
+ };
250
+
251
+ // Forward a TMEM load into the user allocation.
252
+ class TMEMLoadForwarding : public OpRewritePattern <ttng::TMEMAllocOp> {
253
+ public:
254
+ using OpRewritePattern::OpRewritePattern;
255
+
256
+ LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
257
+ PatternRewriter &rewriter) const override {
258
+ if (!alloc.getToken ())
259
+ return failure ();
260
+ Value init = alloc.getSrc ();
261
+ if (!init)
262
+ return failure ();
263
+ auto load = init.getDefiningOp <TMEMTokenLoadOp>();
264
+ if (!load || !load->hasOneUse () || !load.getDep ().hasOneUse ())
265
+ return failure ();
266
+ if (alloc.getType () != load.getSrc ().getType ())
267
+ return failure ();
268
+ rewriter.replaceOp (alloc, {load.getSrc (), load.getDep ()});
269
+ return success ();
270
+ }
271
+ };
272
+
151
273
// Remove loop-carried tensor dependencies if they are fed immediately into a
152
274
// TMEM store by pulling the store into the previous iteration.
153
- class RotateTMEMStoreInLoop : public OpRewritePattern <TMEMTokenStoreOp > {
275
+ class RotateTMEMStoreInLoop : public OpRewritePattern <ttng::TMEMStoreOp > {
154
276
public:
155
277
using OpRewritePattern::OpRewritePattern;
156
278
157
- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
279
+ LogicalResult matchAndRewrite (ttng::TMEMStoreOp store,
158
280
PatternRewriter &rewriter) const override {
281
+ if (!store.getDep ())
282
+ return failure ();
159
283
// Pattern match stores whose source comes from a loop region argument and
160
284
// whose predicate is loop-invariant.
161
285
scf::ForOp forOp = dyn_cast<scf::ForOp>(store->getParentOp ());
@@ -215,12 +339,14 @@ class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {
215
339
216
340
// Remove loop-carried tensor dependencies if they are the result of TMEM loads
217
341
// at the end of the loop by pushing the load into the next iteration.
218
- class RotateTMEMLoadInLoop : public OpRewritePattern <TMEMTokenLoadOp > {
342
+ class RotateTMEMLoadInLoop : public OpRewritePattern <ttng::TMEMLoadOp > {
219
343
public:
220
344
using OpRewritePattern::OpRewritePattern;
221
345
222
- LogicalResult matchAndRewrite (TMEMTokenLoadOp load,
346
+ LogicalResult matchAndRewrite (ttng::TMEMLoadOp load,
223
347
PatternRewriter &rewriter) const override {
348
+ if (!load.getDep ())
349
+ return failure ();
224
350
// Pattern match loads whose results are only passed into the next iteration
225
351
// of a loop.
226
352
scf::ForOp forOp = dyn_cast<scf::ForOp>(load->getParentOp ());
@@ -391,32 +517,55 @@ struct HoistTMEMAlloc
391
517
392
518
void runOnOperation () override {
393
519
ModuleOp m = getOperation ();
394
- SmallVector<ttng::MMAv5OpInterface> mmaOps;
395
- m.walk ([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back (mmaOp); });
396
- for (auto mmaOp : mmaOps) {
397
- auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp ());
398
- if (!forOp) {
399
- continue ;
520
+ if (!hoistOutOfIf) {
521
+ SmallVector<ttng::MMAv5OpInterface> mmaOps;
522
+ m.walk ([&](ttng::MMAv5OpInterface mmaOp) { mmaOps.push_back (mmaOp); });
523
+ for (auto mmaOp : mmaOps) {
524
+ auto forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp ());
525
+ if (!forOp) {
526
+ continue ;
527
+ }
528
+ hoistInvariantInputs (mmaOp, forOp);
529
+
530
+ // Only hoist the TMEM alloc feeding into the accumulator. Leave the
531
+ // ones for the scales in the loop.
532
+ auto alloc = mmaOp.getAccumulator ().getDefiningOp <TMEMTokenAllocOp>();
533
+ if (!alloc || alloc->getParentRegion () != mmaOp->getParentRegion ()) {
534
+ continue ;
535
+ }
536
+ hoistTMEMAlloc (alloc, forOp);
400
537
}
401
- hoistInvariantInputs (mmaOp, forOp);
402
-
403
- // Only hoist the TMEM alloc feeding into the accumulator. Leave the ones
404
- // for the scales in the loop.
405
- auto alloc = mmaOp.getAccumulator ().getDefiningOp <TMEMTokenAllocOp>();
406
- if (!alloc || alloc->getParentRegion () != mmaOp->getParentRegion ()) {
407
- continue ;
408
- }
409
- hoistTMEMAlloc (alloc, forOp);
410
538
}
411
539
412
540
mlir::RewritePatternSet patterns (&getContext ());
413
541
patterns.add <RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
414
542
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
415
543
SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext ());
544
+ if (hoistOutOfIf) {
545
+ patterns.add <CombineTMEMStoreAndAlloc, HoistTMEMAllocOutOfIf,
546
+ TMEMLoadForwarding>(&getContext ());
547
+ }
416
548
scf::ForOp::getCanonicalizationPatterns (patterns, &getContext ());
417
549
if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
418
550
llvm_unreachable (" Failed to hoist tmem_store" );
419
551
}
552
+
553
+ // TODO: currently some code assumes that a mutable tmem alloc doesn't have
554
+ // an initial value. As a workaround we break up the op in order to keep
555
+ // this form for the downstream passes. We should remove this once the
556
+ // downstread passes are fixed.
557
+ m.walk ([&](ttng::TMEMAllocOp alloc) {
558
+ if (alloc.getType ().getMutableMemory () && alloc.getSrc ()) {
559
+ OpBuilder builder (alloc);
560
+ builder.setInsertionPointAfter (alloc);
561
+ auto store = builder.create <ttng::TMEMStoreOp>(
562
+ alloc.getLoc (), builder.getType <AsyncTokenType>(),
563
+ alloc.getResult (), alloc.getToken (), alloc.getSrc (),
564
+ builder.create <arith::ConstantIntOp>(alloc.getLoc (), 1 , 1 ));
565
+ alloc.getToken ().replaceAllUsesExcept (store.getToken (), store);
566
+ alloc.getSrcMutable ().clear ();
567
+ }
568
+ });
420
569
}
421
570
};
422
571
0 commit comments