Skip to content

Commit d9de8e7

Browse files
committed
Fix tutorial assertion
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent c3fdbba commit d9de8e7

File tree

1 file changed

+43
-26
lines changed

1 file changed

+43
-26
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,15 @@ struct CoalescePass
173173
static void propagateLayout(BlockArgument arg, Attribute layout,
174174
IRRewriter &rewriter) {
175175
llvm::errs() << "arg: " << arg << "\n";
176-
for (Operation *user : arg.getUsers()) {
177-
llvm::errs() << "user: " << *user << "\n\n";
176+
177+
auto users = arg.getUsers();
178+
if (users.empty()) {
179+
llvm::errs() << "arg has no users\n";
180+
return;
181+
}
182+
183+
for (Operation *user : users) {
184+
llvm::errs() << "arg's user: " << *user << "\n\n";
178185
if (filterUser(user)) {
179186
llvm::errs() << "SKIP\n";
180187
continue;
@@ -218,9 +225,15 @@ struct CoalescePass
218225
assert(root && root->getNumResults() != 0 &&
219226
"Expecting an operation yielding a result");
220227

221-
// llvm::errs() << "root: " << *root << "\n\n";
222-
for (Operation *user : root->getUsers()) {
223-
llvm::errs() << "user: " << *user << "\n\n";
228+
llvm::errs() << "root: " << *root << "\n";
229+
auto users = root->getUsers();
230+
if (users.empty()) {
231+
llvm::errs() << "root has no users\n";
232+
return;
233+
}
234+
235+
for (Operation *user : users) {
236+
llvm::errs() << "root's user: " << *user << "\n\n";
224237
if (filterUser(user)) {
225238
llvm::errs() << "SKIP\n";
226239
continue;
@@ -262,7 +275,6 @@ struct CoalescePass
262275
if (res == loopArg && tt::isTensorPointerType(res.getType())) {
263276
llvm::errs() << "arg: " << arg << "\n";
264277
llvm::errs() << "loopArg: " << loopArg << "\n";
265-
llvm::errs() << "arg type: " << arg.getType() << "\n";
266278

267279
// Modify the layout of the loop init argument...
268280
tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
@@ -309,7 +321,7 @@ struct CoalescePass
309321
}
310322

311323
void coalesceOp(Attribute encoding, Operation *op) {
312-
llvm::errs() << "Coalescing op: " << *op << "\n";
324+
LDBG("Coalescing op: " << *op);
313325

314326
OpBuilder builder(op);
315327
IRRewriter rewriter(builder);
@@ -362,9 +374,11 @@ struct CoalescePass
362374
}
363375
op->getResult(i).replaceAllUsesWith(newResult);
364376
}
377+
378+
LDBG("Old op: " << *op);
379+
LDBG("newOp: " << *newOp);
365380
op->erase();
366381

367-
llvm::errs() << "newOp: " << *newOp << "\n";
368382
assert(succeeded(verify(newOp)) && "Operation verification failed");
369383
}
370384

@@ -399,12 +413,15 @@ struct CoalescePass
399413
layoutMap);
400414
});
401415

402-
llvm::errs() << "layoutMap:\n";
403-
for (auto [op, encoding] : layoutMap) {
404-
llvm::errs() << "op: " << *op << "\n";
405-
llvm::errs() << "encoding: " << encoding << "\n";
406-
}
407-
llvm::errs() << "\n";
416+
LLVM_DEBUG({
417+
DBGS() << "layoutMap:"
418+
<< "\n";
419+
for (auto [op, encoding] : layoutMap) {
420+
DBGS() << "op: " << *op << "\n";
421+
DBGS() << "encoding: " << encoding << "\n";
422+
}
423+
llvm::errs() << "\n\n";
424+
});
408425

409426
// For each memory op that has a layout L1:
410427
// 1. Create a coalesced memory layout L2 of the pointer operands
@@ -415,22 +432,22 @@ struct CoalescePass
415432
// 5. Replace all the uses of the original memory op by the new one
416433
for (auto [op, layout] : layoutMap) {
417434
coalesceOp(layout, op);
418-
if (failed(verify(moduleOp))) {
419-
for (Operation &op1 : moduleOp.getOps()) {
420-
if (isa<tt::FuncOp>(op1)) {
421-
for (Operation &op2 : cast<tt::FuncOp>(op1).getOps()) {
422-
if (failed(verify(&op2))) {
423-
llvm::errs() << "op2: " << op2 << "\n";
424-
llvm::errs() << "Operation verification failed.\n";
425-
}
435+
}
436+
437+
if (failed(verify(moduleOp))) {
438+
llvm::errs() << "Module verification failed.\n";
439+
llvm::errs() << "mod: " << moduleOp << "\n";
440+
for (Operation &op1 : moduleOp.getOps()) {
441+
if (isa<tt::FuncOp>(op1)) {
442+
for (Operation &op2 : cast<tt::FuncOp>(op1).getOps()) {
443+
if (failed(verify(&op2))) {
444+
llvm::errs() << "op2: " << op2 << "\n";
445+
llvm::errs() << "Operation verification failed.\n";
446+
assert(false);
426447
}
427448
}
428449
}
429-
llvm::errs() << "Module verification failed.\n";
430-
llvm::errs() << "mod: " << moduleOp << "\n";
431-
assert(false);
432450
}
433-
llvm::errs() << "Module verified.\n";
434451
}
435452
}
436453
};

0 commit comments

Comments
 (0)