@@ -148,94 +148,6 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
148
148
}
149
149
};
150
150
151
- // Fold a TMEM store back into the alloc that produced its dep token. This is
152
- // equivalent but gives us a more canonical form to do further optimizations.
153
- class CombineTMEMStoreAndAlloc : public OpRewritePattern <TMEMTokenStoreOp> {
154
- public:
155
- using OpRewritePattern::OpRewritePattern;
156
-
157
- LogicalResult matchAndRewrite (TMEMTokenStoreOp store,
158
- PatternRewriter &rewriter) const override {
159
- if (!matchPattern (store.getPred (), m_One ())) // the store predicate must match m_One()
160
- return failure ();
161
- auto alloc = store.getDep ().getDefiningOp <TMEMTokenAllocOp>(); // op defining the store's dep token, if it is an alloc
162
- if (!alloc)
163
- return failure ();
164
- if (alloc->getBlock () != store->getBlock ()) // only combine when both ops sit in the same block
165
- return failure ();
166
- alloc.getSrcMutable ().assign (store.getSrc ()); // the stored value becomes the alloc's src. NOTE(review): set directly, not through the rewriter, and nothing here checks that the src is defined before the alloc - confirm
167
- rewriter.replaceOp (store, alloc.getToken ()); // the store is replaced by the alloc's token
168
- return success ();
169
- }
170
- };
171
-
172
- // Hoists a tmem alloc outside an if op like this:
173
- // %0 = scf.if {
174
- // %1, %token0 = tmem.alloc %init
175
- // ...
176
- // %2 = tmem.load %1, %token1
177
- // scf.yield %2
178
- // } else {
179
- // scf.yield %init
180
- // }
181
- // ->
182
- // %a, %token0 = tmem.alloc %init
183
- // %token2 = scf.if {
184
- //
185
- // ...
186
- // scf.yield %token1
187
- // } else {
188
- // scf.yield %token0
189
- // }
190
- // %2 = tmem.load %a, %token2
191
- class HoistTMEMAllocOutOfIf : public OpRewritePattern <ttng::TMEMAllocOp> {
192
- public:
193
- using OpRewritePattern::OpRewritePattern;
194
-
195
- LogicalResult matchAndRewrite (ttng::TMEMAllocOp alloc,
196
- PatternRewriter &rewriter) const override {
197
- if (!alloc.getToken ()) // only allocs that produce a token
198
- return failure ();
199
- Value init = alloc.getSrc (); // initial value of the alloc; required below
200
- if (!init)
201
- return failure ();
202
- auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp ()); // the alloc must be directly nested in an scf.if
203
- if (!ifOp)
204
- return failure ();
205
- auto thenOp = ifOp.thenBlock ()->getTerminator ();
206
- auto elseOp = ifOp.elseBlock ()->getTerminator (); // NOTE(review): elseBlock() is dereferenced without a null check - confirm an else block always exists here
207
- SmallVector<int > yieldArgs; // operand numbers of the yields that will be rewritten
208
- for (auto [thenOperand, elseOperand] :
209
- llvm::zip (thenOp->getOpOperands (), elseOp->getOpOperands ())) {
210
- auto load = thenOperand.get ().getDefiningOp <TMEMTokenLoadOp>(); // then side must yield a load ...
211
- if (!load || load.getSrc () != alloc.getResult ()) // ... that reads from this alloc
212
- continue ;
213
- if (elseOperand.get () != init) // else side must yield the alloc's initial value
214
- continue ;
215
- yieldArgs.push_back (thenOperand.getOperandNumber ());
216
- }
217
- if (yieldArgs.empty ()) // nothing matched the pattern
218
- return failure ();
219
- // Since init is used in the else terminator we know that it dominates the
220
- // if op.
221
- alloc->moveBefore (ifOp);
222
- rewriter.setInsertionPointAfter (ifOp); // the cloned loads below are created after the if
223
- for (int argNo : yieldArgs) {
224
- auto load =
225
- cast<TMEMTokenLoadOp>(thenOp->getOperand (argNo).getDefiningOp ());
226
- auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone (*load)); // copy of the load placed after the if
227
- rewriter.modifyOpInPlace (ifOp, [&] {
228
- ifOp->getResult (argNo).replaceAllUsesWith (newLoad.getResult ()); // users of the if result now read the cloned load
229
- newLoad.getDepMutable ().assign (ifOp->getResult (argNo)); // done after the RAUW above, so this new use is not itself replaced
230
- thenOp->setOperand (argNo, load.getToken ()); // then branch yields the original load's token
231
- elseOp->setOperand (argNo, alloc.getToken ()); // else branch yields the alloc's token
232
- ifOp->getResult (argNo).setType (newLoad.getToken ().getType ()); // the if result now has the token type
233
- });
234
- }
235
- return success ();
236
- }
237
- };
238
-
239
151
// Remove loop-carried tensor dependencies if they are fed immediately into a
240
152
// TMEM store by pulling the store into the previous iteration.
241
153
class RotateTMEMStoreInLoop : public OpRewritePattern <TMEMTokenStoreOp> {
@@ -500,29 +412,11 @@ struct HoistTMEMAlloc
500
412
mlir::RewritePatternSet patterns (&getContext ());
501
413
patterns.add <RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
502
414
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
503
- SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
504
- HoistTMEMAllocOutOfIf>(&getContext ());
415
+ SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext ());
505
416
scf::ForOp::getCanonicalizationPatterns (patterns, &getContext ());
506
417
if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
507
418
llvm_unreachable (" Failed to hoist tmem_store" );
508
419
}
509
-
510
- // TODO: currently some code assumes that a mutable tmem alloc doesn't have
511
- // an initial value. As a workaround we break up the op in order to keep
512
- // this form for the downstream passes. We should remove this once the
513
- // downstream passes are fixed.
514
- m.walk ([&](ttng::TMEMAllocOp alloc) {
515
- if (alloc.getType ().getMutableMemory () && alloc.getSrc ()) {
516
- OpBuilder builder (alloc);
517
- builder.setInsertionPointAfter (alloc);
518
- auto store = builder.create <ttng::TMEMStoreOp>(
519
- alloc.getLoc (), builder.getType <AsyncTokenType>(),
520
- alloc.getResult (), alloc.getToken (), alloc.getSrc (),
521
- builder.create <arith::ConstantIntOp>(alloc.getLoc (), 1 , 1 ));
522
- alloc.getToken ().replaceAllUsesExcept (store.getToken (), store);
523
- alloc.getSrcMutable ().clear ();
524
- }
525
- });
526
420
}
527
421
};
528
422
0 commit comments