Skip to content

Commit e1e912b

Browse files
authored
[Coalescing]: Generalize implementation to reduce code duplication (#4863)
This PR generalizes the coalescing implementation to reduce code duplication by unifying the handling of loop operations and value propagation. - Replaces specific scf::ForOp and scf::WhileOp handling with generic LoopLikeOpInterface - Merges separate methods for propagating layouts from operation results and block arguments into a single unified approach --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 33a25b4 commit e1e912b

File tree

1 file changed

+31
-88
lines changed

1 file changed

+31
-88
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 31 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -184,79 +184,26 @@ struct CoalescePass
184184
op->dumpPretty();
185185
});
186186

187-
for (OpResult res : op->getResults())
188-
if (res == opRes)
189-
propagateLayout(op, res, layout, rewriter);
187+
propagateLayout(opRes, layout, rewriter);
190188
}
191189

192-
// Propagate the layout of the \p root operation's result to its users.
193-
void propagateLayout(Operation *op, Value opRes, Attribute layout,
194-
IRRewriter &rewriter) const {
195-
assert(op && op->getNumResults() != 0 &&
196-
"Expecting an operation yielding a result");
197-
assert(opRes &&
198-
llvm::any_of(op->getResults(),
199-
[&](OpResult res) { return res == opRes; }) &&
200-
"Expecting operation to yield 'opRes'");
201-
202-
LLVM_DEBUG({
203-
if (!opRes.getUsers().empty()) {
204-
llvm::dbgs() << "[" DEBUG_TYPE "]: "
205-
<< "Propagate layout to operations using: " << opRes
206-
<< "\n";
207-
}
208-
});
209-
210-
for (Operation *user : opRes.getUsers()) {
211-
if (filterUser(user))
212-
continue;
213-
214-
LLVM_DEBUG({
215-
llvm::dbgs() << "[" DEBUG_TYPE "]: " << "user: ";
216-
user->dumpPretty();
217-
});
218-
219-
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
220-
propagateLayoutToArgsAndBody(forOp, opRes, layout, rewriter);
221-
continue;
222-
}
223-
if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
224-
propagateLayoutToArgsAndBody(whileOp, opRes, layout, rewriter);
225-
continue;
226-
}
227-
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
228-
if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
229-
for (OpOperand &operand : yieldOp->getOpOperands())
230-
if (operand.get() == opRes)
231-
propagateLayoutToLoopResults(loopOp, operand.getOperandNumber(),
232-
layout, rewriter);
233-
continue;
234-
}
235-
}
236-
237-
LLVM_DEBUG({
238-
llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
239-
op->getParentOfType<ModuleOp>()->dumpPretty();
240-
});
241-
242-
for (OpResult res : user->getResults())
243-
changeAndPropagateLayout(user, res, layout, rewriter);
244-
}
245-
}
246-
247-
// Propagate the layout of the \p arg block argument to its users.
248-
void propagateLayout(BlockArgument arg, Attribute layout,
190+
// Propagate \p layout to users of \p val.
191+
void propagateLayout(Value val, Attribute layout,
249192
IRRewriter &rewriter) const {
250193
LLVM_DEBUG({
251-
if (!arg.getUsers().empty()) {
194+
if (!val.getUsers().empty()) {
252195
llvm::dbgs() << "[" DEBUG_TYPE "]: "
253196
<< "Propagate layout to operations using: ";
254-
arg.printAsOperand(llvm::dbgs(), {});
255-
llvm::dbgs() << "\n";
197+
if (isa<BlockArgument>(val)) {
198+
val.printAsOperand(llvm::dbgs(), {});
199+
llvm::dbgs() << "\n";
200+
} else {
201+
llvm::dbgs() << val << "\n";
202+
}
256203
}
257204
});
258205

259-
for (Operation *user : arg.getUsers()) {
206+
for (Operation *user : val.getUsers()) {
260207
if (filterUser(user))
261208
continue;
262209

@@ -266,19 +213,20 @@ struct CoalescePass
266213
});
267214

268215
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
269-
propagateLayoutToArgsAndBody(forOp, arg, layout, rewriter);
216+
propagateLayoutToArgsAndBody(forOp, val, layout, rewriter);
270217
continue;
271218
}
272219
if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
273-
propagateLayoutToArgsAndBody(whileOp, arg, layout, rewriter);
220+
propagateLayoutToArgsAndBody(whileOp, val, layout, rewriter);
274221
continue;
275222
}
276223
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
277224
if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
278-
for (OpOperand &operand : yieldOp->getOpOperands())
279-
if (operand.get() == arg)
280-
propagateLayoutToLoopResults(loopOp, operand.getOperandNumber(),
281-
layout, rewriter);
225+
for (OpOperand &operand : llvm::make_filter_range(
226+
yieldOp->getOpOperands(),
227+
[&val](OpOperand &operand) { return operand.get() == val; }))
228+
propagateLayoutToLoopResult(loopOp, operand.getOperandNumber(),
229+
layout, rewriter);
282230
continue;
283231
}
284232
}
@@ -288,7 +236,7 @@ struct CoalescePass
288236
for (auto [condOperand, loopArg] :
289237
llvm::zip(condOp->getOperands().drop_front(),
290238
whileOp.getAfterArguments())) {
291-
if (condOperand != arg ||
239+
if (condOperand != val ||
292240
!tt::isTensorPointerType(condOperand.getType()))
293241
continue;
294242

@@ -310,24 +258,19 @@ struct CoalescePass
310258
continue;
311259
}
312260

261+
LLVM_DEBUG({
262+
llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
263+
val.getParentRegion()->getParentOfType<ModuleOp>()->dumpPretty();
264+
});
265+
313266
for (OpResult res : user->getResults())
314267
changeAndPropagateLayout(user, res, layout, rewriter);
315268
}
316-
317-
LLVM_DEBUG({
318-
auto mod =
319-
arg.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
320-
llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
321-
mod->dumpPretty();
322-
});
323269
}
324270

325-
// Propagate the layout of the \p root operation's result to the \p loopOp
326-
// loop init argument that uses it, and transitively to the operations in the
327-
// loop body that use that argument.
328-
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
329-
OpType, scf::ForOp, scf::WhileOp>::value>>
330-
void propagateLayoutToArgsAndBody(OpType loopOp, Value opRes,
271+
// Propagate \p layout to the \p loopOp init arguments that use \p opRes, and
272+
// transitively to the operations in the loop body that use those arguments.
273+
void propagateLayoutToArgsAndBody(LoopLikeOpInterface loopOp, Value opRes,
331274
Attribute layout,
332275
IRRewriter &rewriter) const {
333276
for (auto [initArg, arg] :
@@ -354,9 +297,9 @@ struct CoalescePass
354297

355298
// Modify the \p layout to the loop's operand identified by \p resNum, and
356299
// propagate the modified loop results to its users.
357-
void propagateLayoutToLoopResults(LoopLikeOpInterface loopOp, unsigned resNum,
358-
Attribute layout,
359-
IRRewriter &rewriter) const {
300+
void propagateLayoutToLoopResult(LoopLikeOpInterface loopOp, unsigned resNum,
301+
Attribute layout,
302+
IRRewriter &rewriter) const {
360303
Value loopRes = loopOp->getResult(resNum);
361304
rewriter.modifyOpInPlace(loopOp, [&]() {
362305
assert(tt::isTensorPointerType(loopRes.getType()) &&
@@ -368,7 +311,7 @@ struct CoalescePass
368311
ptrType.getAddressSpace()));
369312
});
370313

371-
propagateLayout(loopOp, loopRes, layout, rewriter);
314+
propagateLayout(loopRes, layout, rewriter);
372315
}
373316

374317
void coalesceOp(Attribute encoding, Operation *op) {

0 commit comments

Comments
 (0)