@@ -1129,40 +1129,47 @@ struct TritonRaiseBlockPointer
11291129 void dropMasks (ModuleOp moduleOp) const {
11301130 assert (IgnoreMasks && " Expecting 'IgnoreMask' flag to be set" );
11311131
1132+ SmallVector<Operation *> opsWithMask;
11321133 moduleOp->walk <WalkOrder::PreOrder>([&](Operation *op) {
11331134 TypeSwitch<Operation *>(op)
1134- .Case <tt::LoadOp>([&](auto loadOp) {
1135- if (loadOp.getMask ()) {
1136- loadOp->emitWarning (" TritonRaiseBlockPointer: ignoring mask" );
1137- OpBuilder builder (loadOp);
1138- auto newLoadOp = builder.create <tt::LoadOp>(
1139- loadOp.getLoc (), loadOp.getPtr (), loadOp.getBoundaryCheck (),
1140- loadOp.getPadding (), loadOp.getCache (), loadOp.getEvict (),
1141- loadOp.getIsVolatile ());
1142- loadOp->replaceAllUsesWith (newLoadOp);
1143- loadOp->erase ();
1144- }
1145- return WalkResult::advance ();
1146- })
1147- .Case <tt::StoreOp>([&](auto storeOp) {
1148- if (storeOp.getMask ()) {
1149- storeOp->emitWarning (" TritonRaiseBlockPointer: ignoring mask" );
1150- OpBuilder builder (storeOp);
1151- auto newStoreOp = builder.createOrFold <tt::StoreOp>(
1152- storeOp.getLoc (), storeOp.getPtr (), storeOp.getValue (),
1153- storeOp.getBoundaryCheck (), storeOp.getCache (),
1154- storeOp.getEvict ());
1155-
1156- storeOp->erase ();
1157- if (storeOp.getMask ().getUsers ().empty ())
1158- storeOp.getMask ().getDefiningOp ()->erase ();
1135+ .Case <tt::LoadOp, tt::StoreOp>([&](auto opWithMask) {
1136+ if (opWithMask.getMask ()) {
1137+ opsWithMask.push_back (opWithMask);
11591138 }
11601139 return WalkResult::advance ();
11611140 })
11621141 .Default ([&](auto ) { return WalkResult::advance (); });
11631142 });
11641143
1165- moduleOp.dump ();
1144+ for (Operation *op : opsWithMask) {
1145+ TypeSwitch<Operation *>(op)
1146+ .Case <tt::LoadOp>([&](auto loadOp) {
1147+ loadOp->emitWarning (" TritonRaiseBlockPointer: ignoring mask" );
1148+ OpBuilder builder (loadOp);
1149+ auto newLoadOp = builder.create <tt::LoadOp>(
1150+ loadOp.getLoc (), loadOp.getPtr (), loadOp.getBoundaryCheck (),
1151+ loadOp.getPadding (), loadOp.getCache (), loadOp.getEvict (),
1152+ loadOp.getIsVolatile ());
1153+ loadOp->replaceAllUsesWith (newLoadOp);
1154+ loadOp->erase ();
1155+ })
1156+ .Case <tt::StoreOp>([&](auto storeOp) {
1157+ storeOp->emitWarning (" TritonRaiseBlockPointer: ignoring mask" );
1158+ OpBuilder builder (storeOp);
1159+ auto newStoreOp = builder.createOrFold <tt::StoreOp>(
1160+ storeOp.getLoc (), storeOp.getPtr (), storeOp.getValue (),
1161+ storeOp.getBoundaryCheck (), storeOp.getCache (),
1162+ storeOp.getEvict ());
1163+
1164+ Operation *maskOpToErase = nullptr ;
1165+ if (storeOp.getMask ().hasOneUse ())
1166+ maskOpToErase = storeOp.getMask ().getDefiningOp ();
1167+
1168+ storeOp->erase ();
1169+ if (maskOpToErase)
1170+ maskOpToErase->erase ();
1171+ });
1172+ }
11661173 }
11671174
11681175 static void dump (const IRMapping &map) {
0 commit comments