Commit 754ec70
Cleanup
Signed-off-by: Tiotto, Ettore <[email protected]>
Parent: 949256e

File tree: 1 file changed, +79 −87 lines
third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 79 additions & 87 deletions
@@ -171,11 +171,77 @@ struct CoalescePass
     return false;
   }

-  // Propagate the layout to \p root operation's result to the \p forOp loop
+  // Change the \p layout of the \p op result and propagate the new result type
+  // to its users.
+  void changeAndPropagateLayout(Operation *op, Attribute layout,
+                                IRRewriter &rewriter) const {
+    assert(op && op->getNumResults() == 1 &&
+           "Expecting operation yielding a result");
+
+    rewriter.modifyOpInPlace(op, [&]() {
+      Value res = op->getOpResult(0);
+      assert(tt::isTensorPointerType(res.getType()) &&
+             "Expecting a block pointer");
+
+      auto ptrType = cast<tt::PointerType>(res.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                       ptrType.getAddressSpace()));
+    });
+    LDBG("Coalesced op: " << *op);
+
+    propagateLayout(op, layout, rewriter);
+  }
+
+  // Propagate the layout of the \p root operation's result to its users.
+  void propagateLayout(Operation *root, Attribute layout,
+                       IRRewriter &rewriter) const {
+    assert(root->getNumResults() != 0 &&
+           "Expecting an operation yielding a result");
+
+    LDBG("root: " << *root);
+    for (Operation *user : root->getUsers()) {
+      if (filterUser(user))
+        continue;
+
+      LDBG("root's user: " << *user << "\n");
+      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+        propagateLayoutToArgsAndBody(forOp, root, layout, rewriter);
+        continue;
+      }
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
+        propagateLayoutToLoopResults(forOp, layout, rewriter);
+        continue;
+      }
+      changeAndPropagateLayout(user, layout, rewriter);
+    }
+  }
+
+  // Propagate the layout of the \p arg block argument to its users.
+  void propagateLayout(BlockArgument arg, Attribute layout,
+                       IRRewriter &rewriter) const {
+    LDBG("arg: " << arg);
+    for (Operation *user : arg.getUsers()) {
+      if (filterUser(user))
+        continue;
+
+      LDBG("arg's user: " << *user << "\n");
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
+        propagateLayoutToLoopResults(forOp, layout, rewriter);
+        continue;
+      }
+      changeAndPropagateLayout(user, layout, rewriter);
+    }
+  }
+
+  // Propagate the layout of the \p root operation's result to the \p forOp loop
   // init argument that uses it, and transitively to the operations in the loop
   // body that use that argument.
-  static void propagate(scf::ForOp forOp, Operation *root, Attribute layout,
-                        IRRewriter &rewriter) {
+  void propagateLayoutToArgsAndBody(scf::ForOp forOp, Operation *root,
+                                    Attribute layout,
+                                    IRRewriter &rewriter) const {
     assert(llvm::any_of(root->getUsers(),
                         [&](Operation *user) { return user == forOp; }) &&
            "Expecting the loop to be a user of the root operation");
@@ -202,8 +268,8 @@ struct CoalescePass

   // Modify the given loop \p forOp and propagate the result of the enclosing
   // loop.
-  static void propagate(scf::ForOp forOp, Attribute layout,
-                        IRRewriter &rewriter) {
+  void propagateLayoutToLoopResults(scf::ForOp forOp, Attribute layout,
+                                    IRRewriter &rewriter) const {
     Operation *yieldOp = forOp.getBody()->getTerminator();

     rewriter.modifyOpInPlace(forOp, [&]() {
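Note: crossing the loop boundary is the subtle part of this propagation. A value that feeds an scf.for appears in three places that must always agree on type: the init operand, the body's block argument (iter_arg), and the loop result. The following toy sketch captures that invariant with a hypothetical Loop record; the pass achieves the same effect on a real scf::ForOp by combining propagateLayoutToArgsAndBody (init and iter args) with propagateLayoutToLoopResults (results, reached via the scf.yield user):

    #include <cassert>
    #include <string>
    #include <vector>

    // Toy model of an scf.for's loop-carried values.
    struct Loop {
      std::vector<std::string> initArgTypes; // operands fed into the loop
      std::vector<std::string> iterArgTypes; // block arguments of the body
      std::vector<std::string> resultTypes;  // values yielded out of the loop
    };

    // Retype one loop-carried slot everywhere at once so the loop stays
    // type-consistent after the layout change.
    void retypeLoopCarriedValue(Loop &loop, unsigned idx,
                                const std::string &ty) {
      assert(idx < loop.initArgTypes.size() && "slot out of range");
      loop.initArgTypes[idx] = ty;
      loop.iterArgTypes[idx] = ty;
      loop.resultTypes[idx] = ty;
    }

    int main() {
      Loop loop{{"ptr<old>"}, {"ptr<old>"}, {"ptr<old>"}};
      retypeLoopCarriedValue(loop, 0, "ptr<coalesced>");
    }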
@@ -229,69 +295,6 @@ struct CoalescePass
     propagateLayout(forOp, layout, rewriter);
   }

-  static void propagateLayout(BlockArgument arg, Attribute layout,
-                              IRRewriter &rewriter) {
-    LDBG("arg: " << arg);
-    for (Operation *user : arg.getUsers()) {
-      LDBG("arg's user: " << *user << "\n");
-      if (filterUser(user)) {
-        continue;
-      }
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-        propagate(forOp, layout, rewriter);
-        continue;
-      }
-      changeAndPropagateLayout(user, layout, rewriter);
-    }
-  }
-
-  static void propagateLayout(Operation *root, Attribute layout,
-                              IRRewriter &rewriter) {
-    assert(root->getNumResults() != 0 &&
-           "Expecting an operation yielding a result");
-
-    LDBG("root: " << *root);
-    for (Operation *user : root->getUsers()) {
-      LDBG("root's user: " << *user << "\n");
-      if (filterUser(user)) {
-        continue;
-      }
-      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-        propagate(forOp, root, layout, rewriter);
-        continue;
-      }
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-        propagate(forOp, layout, rewriter);
-        continue;
-      }
-      changeAndPropagateLayout(user, layout, rewriter);
-    }
-  }
-
-  // Change the \p layout of the \p op result and propagate the new result type
-  // to its users.
-  static void changeAndPropagateLayout(Operation *op, Attribute layout,
-                                       IRRewriter &rewriter) {
-    assert(op && op->getNumResults() == 1 &&
-           "Expecting operation yielding a result");
-
-    rewriter.modifyOpInPlace(op, [&]() {
-      Value res = op->getOpResult(0);
-      assert(tt::isTensorPointerType(res.getType()) &&
-             "Expecting a block pointer");
-
-      auto ptrType = cast<tt::PointerType>(res.getType());
-      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                       ptrType.getAddressSpace()));
-    });
-    LDBG("Coalesced op: " << *op);
-
-    propagateLayout(op, layout, rewriter);
-  }
-
   void coalesceOp(Attribute encoding, Operation *op) {
     LDBG("Coalescing op: " << *op);

@@ -316,8 +319,7 @@ struct CoalescePass
              "Expecting operand to have blocked pointer type");
       auto defOp = findDefiningMakeTensorPtrOp(operand);
       assert(defOp && "Expected a make_tensor_ptr operation");
-
-      llvm::errs() << "Found make_tensor_ptr definition: " << *defOp << "\n";
+      LDBG("Found make_tensor_ptr definition: " << *defOp);
       changeAndPropagateLayout(*defOp, encoding, rewriter);
       newArgs.push_back(operand);
     }
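Note: this hunk swaps an unconditional llvm::errs() print for LDBG, which only emits output under the pass's debug flag. A stand-alone imitation of the pattern is sketched below; the real macros come from llvm/Support/Debug.h and the file's DEBUG_TYPE setup, and this NDEBUG-based version only mimics the effect:

    #include <iostream>

    // Imitation of the LDBG/DBGS idiom: debug logging that vanishes from
    // release builds instead of always writing to stderr like llvm::errs().
    #ifndef NDEBUG
    #define LDBG(x) (std::cerr << "[Coalesce] " << x << '\n')
    #else
    #define LDBG(x) ((void)0)
    #endif

    int main() {
      LDBG("Found make_tensor_ptr definition: " << "tt.make_tensor_ptr(...)");
    }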
@@ -326,8 +328,7 @@ struct CoalescePass
     // Convert output types
     SmallVector<Type, 4> newTypes;
     for (auto t : op->getResultTypes()) {
-      bool isAsync = isa<ttg::AsyncCopyGlobalToLocalOp>(op);
-      assert(!isAsync &&
+      assert(!isa<ttg::AsyncCopyGlobalToLocalOp>(op) &&
              "AsyncCopyGlobalToLocalOp not supported for Intel GPU");
       newTypes.push_back(getNewType(cast<RankedTensorType>(t), encoding));
     }
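Note: one plausible motivation for folding the temporary into the assert, beyond brevity: assert() is stripped in release builds, so a named bool used only by the assert would become unused and trip -Wunused-variable. A tiny illustration, with a hypothetical sideEffectFree() standing in for the isa<> check (this folding is only safe because the predicate has no side effects):

    #include <cassert>

    bool sideEffectFree() { return true; } // stands in for isa<...>(op)

    int main() {
      // Old shape: "bool ok = sideEffectFree();" followed by assert(ok && ...)
      // leaves `ok` unused once NDEBUG removes the assert.
      // New shape: nothing survives a release build.
      assert(sideEffectFree() && "AsyncCopyGlobalToLocalOp not supported");
    }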
@@ -379,7 +380,8 @@ struct CoalescePass
     });

     LLVM_DEBUG({
-      DBGS() << "layoutMap:" << "\n";
+      DBGS() << "layoutMap:"
+             << "\n";
       for (auto [op, encoding] : layoutMap) {
         DBGS() << "op: " << *op << "\n";
         DBGS() << "encoding: " << encoding << "\n";
@@ -398,20 +400,10 @@ struct CoalescePass
       coalesceOp(layout, op);
     }

-    if (failed(verify(moduleOp))) {
-      llvm::errs() << "Module verification failed.\n";
-      llvm::errs() << "mod: " << moduleOp << "\n";
-      for (Operation &op1 : moduleOp.getOps()) {
-        if (isa<tt::FuncOp>(op1)) {
-          for (Operation &op2 : cast<tt::FuncOp>(op1).getOps()) {
-            if (failed(verify(&op2))) {
-              llvm::errs() << "op2: " << op2 << "\n";
-              llvm::errs() << "Operation verification failed.\n";
-              assert(false);
-            }
-          }
-        }
-      }
+    // Verify the module's functions after the transformation.
+    for (auto op : moduleOp.getOps<tt::FuncOp>()) {
+      for (Operation &op1 : op.getOps())
+        assert(succeeded(verify(&op1)));
     }
   }
 };
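Note: the rewritten check walks only the tt::FuncOp bodies and asserts that verification succeeds on each nested op, replacing the hand-rolled diagnostic dump. A minimal sketch of that shape, with toy Op/Func/verify types standing in for mlir::Operation, tt::FuncOp, and mlir::verify:

    #include <cassert>
    #include <vector>

    // Toy stand-ins; only the shape of the post-pass check is modeled.
    struct Op { bool valid = true; };
    struct Func { std::vector<Op> body; };

    bool verify(const Op &op) { return op.valid; }

    void verifyModuleFuncs(const std::vector<Func> &funcs) {
      for (const Func &func : funcs)
        for (const Op &op : func.body)
          assert(verify(op) && "operation verification failed after coalescing");
    }

    int main() {
      std::vector<Func> module{{{Op{}, Op{}}}};
      verifyModuleFuncs(module);
    }

Like any assert, this check compiles away in release builds; it is a debug-build safety net rather than user-facing verification.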
