Skip to content

Commit c5179c8

Browse files
[TritonRaiseBlockPointer] Fix non-deterministic failure (#3492)
1. Erasing operations while iterating operation would cause unexpected behavior. 2. Cannot access operations that are already erased. --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 3e1642d commit c5179c8

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,40 +1129,47 @@ struct TritonRaiseBlockPointer
11291129
void dropMasks(ModuleOp moduleOp) const {
11301130
assert(IgnoreMasks && "Expecting 'IgnoreMask' flag to be set");
11311131

1132+
SmallVector<Operation *> opsWithMask;
11321133
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
11331134
TypeSwitch<Operation *>(op)
1134-
.Case<tt::LoadOp>([&](auto loadOp) {
1135-
if (loadOp.getMask()) {
1136-
loadOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
1137-
OpBuilder builder(loadOp);
1138-
auto newLoadOp = builder.create<tt::LoadOp>(
1139-
loadOp.getLoc(), loadOp.getPtr(), loadOp.getBoundaryCheck(),
1140-
loadOp.getPadding(), loadOp.getCache(), loadOp.getEvict(),
1141-
loadOp.getIsVolatile());
1142-
loadOp->replaceAllUsesWith(newLoadOp);
1143-
loadOp->erase();
1144-
}
1145-
return WalkResult::advance();
1146-
})
1147-
.Case<tt::StoreOp>([&](auto storeOp) {
1148-
if (storeOp.getMask()) {
1149-
storeOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
1150-
OpBuilder builder(storeOp);
1151-
auto newStoreOp = builder.createOrFold<tt::StoreOp>(
1152-
storeOp.getLoc(), storeOp.getPtr(), storeOp.getValue(),
1153-
storeOp.getBoundaryCheck(), storeOp.getCache(),
1154-
storeOp.getEvict());
1155-
1156-
storeOp->erase();
1157-
if (storeOp.getMask().getUsers().empty())
1158-
storeOp.getMask().getDefiningOp()->erase();
1135+
.Case<tt::LoadOp, tt::StoreOp>([&](auto opWithMask) {
1136+
if (opWithMask.getMask()) {
1137+
opsWithMask.push_back(opWithMask);
11591138
}
11601139
return WalkResult::advance();
11611140
})
11621141
.Default([&](auto) { return WalkResult::advance(); });
11631142
});
11641143

1165-
moduleOp.dump();
1144+
for (Operation *op : opsWithMask) {
1145+
TypeSwitch<Operation *>(op)
1146+
.Case<tt::LoadOp>([&](auto loadOp) {
1147+
loadOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
1148+
OpBuilder builder(loadOp);
1149+
auto newLoadOp = builder.create<tt::LoadOp>(
1150+
loadOp.getLoc(), loadOp.getPtr(), loadOp.getBoundaryCheck(),
1151+
loadOp.getPadding(), loadOp.getCache(), loadOp.getEvict(),
1152+
loadOp.getIsVolatile());
1153+
loadOp->replaceAllUsesWith(newLoadOp);
1154+
loadOp->erase();
1155+
})
1156+
.Case<tt::StoreOp>([&](auto storeOp) {
1157+
storeOp->emitWarning("TritonRaiseBlockPointer: ignoring mask");
1158+
OpBuilder builder(storeOp);
1159+
auto newStoreOp = builder.createOrFold<tt::StoreOp>(
1160+
storeOp.getLoc(), storeOp.getPtr(), storeOp.getValue(),
1161+
storeOp.getBoundaryCheck(), storeOp.getCache(),
1162+
storeOp.getEvict());
1163+
1164+
Operation *maskOpToErase = nullptr;
1165+
if (storeOp.getMask().hasOneUse())
1166+
maskOpToErase = storeOp.getMask().getDefiningOp();
1167+
1168+
storeOp->erase();
1169+
if (maskOpToErase)
1170+
maskOpToErase->erase();
1171+
});
1172+
}
11661173
}
11671174

11681175
static void dump(const IRMapping &map) {

0 commit comments

Comments
 (0)