@@ -34,24 +34,21 @@ static Operation *dropMask(Operation *op, bool maskVal) {
3434 TypeSwitch<Operation *>(op)
3535 .Case <tt::LoadOp>([&](auto loadOp) {
3636 if (maskVal) {
37- tt::LoadOp newLoadOp = builder.create <tt::LoadOp>(
38- loc, loadOp.getPtr (), loadOp.getCache (), loadOp.getEvict (),
37+ auto newLoadOp = builder.create <tt::LoadOp>(
38+ loc, loadOp.getPtr (), loadOp.getBoundaryCheck (),
39+ loadOp.getPadding (), loadOp.getCache (), loadOp.getEvict (),
3940 loadOp.getIsVolatile ());
4041 loadOp->replaceAllUsesWith (newLoadOp);
4142 } else {
42- Value other = loadOp. getOther ();
43- Operation *cstOp = builder.create <arith::ConstantOp>(loc, other );
43+ Operation *cstOp =
44+ builder.create <arith::ConstantOp>(loc, loadOp. getOther () );
4445 loadOp->replaceAllUsesWith (cstOp);
4546 }
4647 })
4748 .Case <arith::SelectOp>([&](auto selectOp) {
4849 selectOp->replaceAllUsesWith (
4950 (maskVal ? selectOp.getTrueValue () : selectOp.getFalseValue ())
5051 .getDefiningOp ());
51- })
52- .Default ([](auto ) {
53- return nullptr ;
54- llvm_unreachable (" Unexpected operation" );
5552 });
5653
5754 return nullptr ;
@@ -673,7 +670,7 @@ struct TritonIntelRemoveMasksBase
673670 void runOnOperation () final {
674671 ModuleOp moduleOp = getOperation ();
675672
676- // Remove masks if the are not necessary
673+ // Remove masks if they are not necessary.
677674 moduleOp->walk <WalkOrder::PreOrder>([&](Operation *op) {
678675 if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
679676 // Nested loop aren't currently handled.
0 commit comments