@@ -184,79 +184,26 @@ struct CoalescePass
184
184
op->dumpPretty ();
185
185
});
186
186
187
- for (OpResult res : op->getResults ())
188
- if (res == opRes)
189
- propagateLayout (op, res, layout, rewriter);
187
+ propagateLayout (opRes, layout, rewriter);
190
188
}
191
189
192
- // Propagate the layout of the \p root operation's result to its users.
193
- void propagateLayout (Operation *op, Value opRes, Attribute layout,
194
- IRRewriter &rewriter) const {
195
- assert (op && op->getNumResults () != 0 &&
196
- " Expecting an operation yielding a result" );
197
- assert (opRes &&
198
- llvm::any_of (op->getResults (),
199
- [&](OpResult res) { return res == opRes; }) &&
200
- " Expecting operation to yield 'opRes'" );
201
-
202
- LLVM_DEBUG ({
203
- if (!opRes.getUsers ().empty ()) {
204
- llvm::dbgs () << " [" DEBUG_TYPE " ]: "
205
- << " Propagate layout to operations using: " << opRes
206
- << " \n " ;
207
- }
208
- });
209
-
210
- for (Operation *user : opRes.getUsers ()) {
211
- if (filterUser (user))
212
- continue ;
213
-
214
- LLVM_DEBUG ({
215
- llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " user: " ;
216
- user->dumpPretty ();
217
- });
218
-
219
- if (auto forOp = dyn_cast<scf::ForOp>(user)) {
220
- propagateLayoutToArgsAndBody (forOp, opRes, layout, rewriter);
221
- continue ;
222
- }
223
- if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
224
- propagateLayoutToArgsAndBody (whileOp, opRes, layout, rewriter);
225
- continue ;
226
- }
227
- if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
228
- if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
229
- for (OpOperand &operand : yieldOp->getOpOperands ())
230
- if (operand.get () == opRes)
231
- propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
232
- layout, rewriter);
233
- continue ;
234
- }
235
- }
236
-
237
- LLVM_DEBUG ({
238
- llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
239
- op->getParentOfType <ModuleOp>()->dumpPretty ();
240
- });
241
-
242
- for (OpResult res : user->getResults ())
243
- changeAndPropagateLayout (user, res, layout, rewriter);
244
- }
245
- }
246
-
247
- // Propagate the layout of the \p arg block argument to its users.
248
- void propagateLayout (BlockArgument arg, Attribute layout,
190
+ // Propagate \p layout to users of \p val.
191
+ void propagateLayout (Value val, Attribute layout,
249
192
IRRewriter &rewriter) const {
250
193
LLVM_DEBUG ({
251
- if (!arg .getUsers ().empty ()) {
194
+ if (!val .getUsers ().empty ()) {
252
195
llvm::dbgs () << " [" DEBUG_TYPE " ]: "
253
196
<< " Propagate layout to operations using: " ;
254
- arg.printAsOperand (llvm::dbgs (), {});
255
- llvm::dbgs () << " \n " ;
197
+ if (isa<BlockArgument>(val)) {
198
+ val.printAsOperand (llvm::dbgs (), {});
199
+ llvm::dbgs () << " \n " ;
200
+ } else {
201
+ llvm::dbgs () << val << " \n " ;
202
+ }
256
203
}
257
204
});
258
205
259
- for (Operation *user : arg .getUsers ()) {
206
+ for (Operation *user : val .getUsers ()) {
260
207
if (filterUser (user))
261
208
continue ;
262
209
@@ -266,19 +213,20 @@ struct CoalescePass
266
213
});
267
214
268
215
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
269
- propagateLayoutToArgsAndBody (forOp, arg , layout, rewriter);
216
+ propagateLayoutToArgsAndBody (forOp, val , layout, rewriter);
270
217
continue ;
271
218
}
272
219
if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
273
- propagateLayoutToArgsAndBody (whileOp, arg , layout, rewriter);
220
+ propagateLayoutToArgsAndBody (whileOp, val , layout, rewriter);
274
221
continue ;
275
222
}
276
223
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
277
224
if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
278
- for (OpOperand &operand : yieldOp->getOpOperands ())
279
- if (operand.get () == arg)
280
- propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
281
- layout, rewriter);
225
+ for (OpOperand &operand : llvm::make_filter_range (
226
+ yieldOp->getOpOperands (),
227
+ [&val](OpOperand &operand) { return operand.get () == val; }))
228
+ propagateLayoutToLoopResult (loopOp, operand.getOperandNumber (),
229
+ layout, rewriter);
282
230
continue ;
283
231
}
284
232
}
@@ -288,7 +236,7 @@ struct CoalescePass
288
236
for (auto [condOperand, loopArg] :
289
237
llvm::zip (condOp->getOperands ().drop_front (),
290
238
whileOp.getAfterArguments ())) {
291
- if (condOperand != arg ||
239
+ if (condOperand != val ||
292
240
!tt::isTensorPointerType (condOperand.getType ()))
293
241
continue ;
294
242
@@ -310,24 +258,19 @@ struct CoalescePass
310
258
continue ;
311
259
}
312
260
261
+ LLVM_DEBUG ({
262
+ llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
263
+ val.getParentRegion ()->getParentOfType <ModuleOp>()->dumpPretty ();
264
+ });
265
+
313
266
for (OpResult res : user->getResults ())
314
267
changeAndPropagateLayout (user, res, layout, rewriter);
315
268
}
316
-
317
- LLVM_DEBUG ({
318
- auto mod =
319
- arg.getParentBlock ()->getParentOp ()->getParentOfType <ModuleOp>();
320
- llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
321
- mod->dumpPretty ();
322
- });
323
269
}
324
270
325
- // Propagate the layout of the \p root operation's result to the \p loopOp
326
- // loop init argument that uses it, and transitively to the operations in the
327
- // loop body that use that argument.
328
- template <typename OpType, typename = std::enable_if_t <llvm::is_one_of<
329
- OpType, scf::ForOp, scf::WhileOp>::value>>
330
- void propagateLayoutToArgsAndBody (OpType loopOp, Value opRes,
271
+ // Propagate \p layout to the \p loopOp init arguments that use \p opRes, and
272
+ // transitively to the operations in the loop body that use those arguments.
273
+ void propagateLayoutToArgsAndBody (LoopLikeOpInterface loopOp, Value opRes,
331
274
Attribute layout,
332
275
IRRewriter &rewriter) const {
333
276
for (auto [initArg, arg] :
@@ -354,9 +297,9 @@ struct CoalescePass
354
297
355
298
// Modify the \p layout to the loop's operand identified by \p resNum, and
356
299
// propagate the modified loop results to its users.
357
- void propagateLayoutToLoopResults (LoopLikeOpInterface loopOp, unsigned resNum,
358
- Attribute layout,
359
- IRRewriter &rewriter) const {
300
+ void propagateLayoutToLoopResult (LoopLikeOpInterface loopOp, unsigned resNum,
301
+ Attribute layout,
302
+ IRRewriter &rewriter) const {
360
303
Value loopRes = loopOp->getResult (resNum);
361
304
rewriter.modifyOpInPlace (loopOp, [&]() {
362
305
assert (tt::isTensorPointerType (loopRes.getType ()) &&
@@ -368,7 +311,7 @@ struct CoalescePass
368
311
ptrType.getAddressSpace ()));
369
312
});
370
313
371
- propagateLayout (loopOp, loopRes, layout, rewriter);
314
+ propagateLayout (loopRes, layout, rewriter);
372
315
}
373
316
374
317
void coalesceOp (Attribute encoding, Operation *op) {
0 commit comments