@@ -184,79 +184,26 @@ struct CoalescePass
184184 op->dumpPretty ();
185185 });
186186
187- for (OpResult res : op->getResults ())
188- if (res == opRes)
189- propagateLayout (op, res, layout, rewriter);
187+ propagateLayout (opRes, layout, rewriter);
190188 }
191189
192- // Propagate the layout of the \p root operation's result to its users.
193- void propagateLayout (Operation *op, Value opRes, Attribute layout,
194- IRRewriter &rewriter) const {
195- assert (op && op->getNumResults () != 0 &&
196- " Expecting an operation yielding a result" );
197- assert (opRes &&
198- llvm::any_of (op->getResults (),
199- [&](OpResult res) { return res == opRes; }) &&
200- " Expecting operation to yield 'opRes'" );
201-
202- LLVM_DEBUG ({
203- if (!opRes.getUsers ().empty ()) {
204- llvm::dbgs () << " [" DEBUG_TYPE " ]: "
205- << " Propagate layout to operations using: " << opRes
206- << " \n " ;
207- }
208- });
209-
210- for (Operation *user : opRes.getUsers ()) {
211- if (filterUser (user))
212- continue ;
213-
214- LLVM_DEBUG ({
215- llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " user: " ;
216- user->dumpPretty ();
217- });
218-
219- if (auto forOp = dyn_cast<scf::ForOp>(user)) {
220- propagateLayoutToArgsAndBody (forOp, opRes, layout, rewriter);
221- continue ;
222- }
223- if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
224- propagateLayoutToArgsAndBody (whileOp, opRes, layout, rewriter);
225- continue ;
226- }
227- if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
228- if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
229- for (OpOperand &operand : yieldOp->getOpOperands ())
230- if (operand.get () == opRes)
231- propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
232- layout, rewriter);
233- continue ;
234- }
235- }
236-
237- LLVM_DEBUG ({
238- llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
239- op->getParentOfType <ModuleOp>()->dumpPretty ();
240- });
241-
242- for (OpResult res : user->getResults ())
243- changeAndPropagateLayout (user, res, layout, rewriter);
244- }
245- }
246-
247- // Propagate the layout of the \p arg block argument to its users.
248- void propagateLayout (BlockArgument arg, Attribute layout,
190+ // Propagate \p layout to users of \p val.
191+ void propagateLayout (Value val, Attribute layout,
249192 IRRewriter &rewriter) const {
250193 LLVM_DEBUG ({
251- if (!arg .getUsers ().empty ()) {
194+ if (!val .getUsers ().empty ()) {
252195 llvm::dbgs () << " [" DEBUG_TYPE " ]: "
253196 << " Propagate layout to operations using: " ;
254- arg.printAsOperand (llvm::dbgs (), {});
255- llvm::dbgs () << " \n " ;
197+ if (isa<BlockArgument>(val)) {
198+ val.printAsOperand (llvm::dbgs (), {});
199+ llvm::dbgs () << " \n " ;
200+ } else {
201+ llvm::dbgs () << val << " \n " ;
202+ }
256203 }
257204 });
258205
259- for (Operation *user : arg .getUsers ()) {
206+ for (Operation *user : val .getUsers ()) {
260207 if (filterUser (user))
261208 continue ;
262209
@@ -266,19 +213,20 @@ struct CoalescePass
266213 });
267214
268215 if (auto forOp = dyn_cast<scf::ForOp>(user)) {
269- propagateLayoutToArgsAndBody (forOp, arg , layout, rewriter);
216+ propagateLayoutToArgsAndBody (forOp, val , layout, rewriter);
270217 continue ;
271218 }
272219 if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
273- propagateLayoutToArgsAndBody (whileOp, arg , layout, rewriter);
220+ propagateLayoutToArgsAndBody (whileOp, val , layout, rewriter);
274221 continue ;
275222 }
276223 if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
277224 if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
278- for (OpOperand &operand : yieldOp->getOpOperands ())
279- if (operand.get () == arg)
280- propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
281- layout, rewriter);
225+ for (OpOperand &operand : llvm::make_filter_range (
226+ yieldOp->getOpOperands (),
227+ [&val](OpOperand &operand) { return operand.get () == val; }))
228+ propagateLayoutToLoopResult (loopOp, operand.getOperandNumber (),
229+ layout, rewriter);
282230 continue ;
283231 }
284232 }
@@ -288,7 +236,7 @@ struct CoalescePass
288236 for (auto [condOperand, loopArg] :
289237 llvm::zip (condOp->getOperands ().drop_front (),
290238 whileOp.getAfterArguments ())) {
291- if (condOperand != arg ||
239+ if (condOperand != val ||
292240 !tt::isTensorPointerType (condOperand.getType ()))
293241 continue ;
294242
@@ -310,24 +258,19 @@ struct CoalescePass
310258 continue ;
311259 }
312260
261+ LLVM_DEBUG ({
262+ llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
263+ val.getParentRegion ()->getParentOfType <ModuleOp>()->dumpPretty ();
264+ });
265+
313266 for (OpResult res : user->getResults ())
314267 changeAndPropagateLayout (user, res, layout, rewriter);
315268 }
316-
317- LLVM_DEBUG ({
318- auto mod =
319- arg.getParentBlock ()->getParentOp ()->getParentOfType <ModuleOp>();
320- llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
321- mod->dumpPretty ();
322- });
323269 }
324270
325- // Propagate the layout of the \p root operation's result to the \p loopOp
326- // loop init argument that uses it, and transitively to the operations in the
327- // loop body that use that argument.
328- template <typename OpType, typename = std::enable_if_t <llvm::is_one_of<
329- OpType, scf::ForOp, scf::WhileOp>::value>>
330- void propagateLayoutToArgsAndBody (OpType loopOp, Value opRes,
271+ // Propagate \p layout to the \p loopOp init arguments that use \p opRes, and
272+ // transitively to the operations in the loop body that use those arguments.
273+ void propagateLayoutToArgsAndBody (LoopLikeOpInterface loopOp, Value opRes,
331274 Attribute layout,
332275 IRRewriter &rewriter) const {
333276 for (auto [initArg, arg] :
@@ -354,9 +297,9 @@ struct CoalescePass
354297
355298 // Apply \p layout to the loop result identified by \p resNum, and
356299 // propagate it to the users of that result.
357- void propagateLayoutToLoopResults (LoopLikeOpInterface loopOp, unsigned resNum,
358- Attribute layout,
359- IRRewriter &rewriter) const {
300+ void propagateLayoutToLoopResult (LoopLikeOpInterface loopOp, unsigned resNum,
301+ Attribute layout,
302+ IRRewriter &rewriter) const {
360303 Value loopRes = loopOp->getResult (resNum);
361304 rewriter.modifyOpInPlace (loopOp, [&]() {
362305 assert (tt::isTensorPointerType (loopRes.getType ()) &&
@@ -368,7 +311,7 @@ struct CoalescePass
368311 ptrType.getAddressSpace ()));
369312 });
370313
371- propagateLayout (loopOp, loopRes, layout, rewriter);
314+ propagateLayout (loopRes, layout, rewriter);
372315 }
373316
374317 void coalesceOp (Attribute encoding, Operation *op) {
0 commit comments