22#include " intel/include/Utils/Utility.h"
33#include " mlir/Dialect/Arith/IR/Arith.h"
44#include " mlir/Dialect/SCF/IR/SCF.h"
5+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
6+ #include " mlir/IR/OpDefinition.h"
57#include " mlir/IR/Verifier.h"
8+ #include " mlir/Interfaces/InferIntRangeInterface.h"
9+ #include " mlir/Support/LLVM.h"
610#include " triton/Dialect/Triton/IR/Dialect.h"
11+ #include " llvm/ADT/TypeSwitch.h"
12+ #include " llvm/IR/Instructions.h"
713#include " llvm/Support/Debug.h"
814#include " llvm/Support/raw_ostream.h"
15+ #include < optional>
916
1017#define DEBUG_TYPE " triton-intel-remove-masks"
1118
@@ -19,12 +26,41 @@ namespace mlir::triton::intel {
1926
2027namespace {
2128
29+ static Operation *dropMask (Operation *op, bool maskVal) {
30+ assert (op && " Expecting a valid operation" );
31+
32+ OpBuilder builder (op);
33+ Location loc = op->getLoc ();
34+ TypeSwitch<Operation *>(op)
35+ .Case <tt::LoadOp>([&](auto loadOp) {
36+ if (maskVal) {
37+ tt::LoadOp newLoadOp = builder.create <tt::LoadOp>(
38+ loc, loadOp.getPtr (), loadOp.getCache (), loadOp.getEvict (),
39+ loadOp.getIsVolatile ());
40+ loadOp->replaceAllUsesWith (newLoadOp);
41+ } else {
42+ Value other = loadOp.getOther ();
43+ Operation *cstOp = builder.create <arith::ConstantOp>(loc, other);
44+ loadOp->replaceAllUsesWith (cstOp);
45+ }
46+ })
47+ .Case <arith::SelectOp>([&](auto selectOp) {
48+ selectOp->replaceAllUsesWith (
49+ (maskVal ? selectOp.getTrueValue () : selectOp.getFalseValue ())
50+ .getDefiningOp ());
51+ })
52+ .Default ([](auto ) {
53+ return nullptr ;
54+ llvm_unreachable (" Unexpected operation" );
55+ });
56+
57+ return nullptr ;
58+ }
59+
2260// Abstract base class for mask validators.
2361// Mask validators are used to check whether a given mask has an expected form.
2462// Concrete subclasses provide a member function used to select masked
2563// operations that have a mask in a particular (e.g. desired) form.
26- // Furthermore concrete mask validators classes might also provide a member
27- // function
2864class MaskValidatorBase {
2965public:
3066 virtual ~MaskValidatorBase () = default ;
@@ -38,6 +74,112 @@ class MaskValidatorBase {
3874 virtual std::string getName () const = 0;
3975};
4076
77+ // A mask validator which ensures the mask is not necessary.
78+ class RemovableMaskValidator final : public MaskValidatorBase {
79+ public:
80+ virtual bool isValidMask (scf::ForOp &forOp, Value mask) const {
81+ Value finalVal = tt::intel::getFinalValue (mask);
82+ assert (finalVal && " Expecting a valid mask" );
83+
84+ // Ensure the loop range is known.
85+ std::optional<ConstantIntRanges> optRange = computeLoopIVRange (forOp);
86+ if (!optRange)
87+ return false ;
88+
89+ if (!finalVal.getDefiningOp () ||
90+ !isa<arith::CmpIOp>(finalVal.getDefiningOp ()))
91+ return false ;
92+
93+ auto cmpOp = cast<arith::CmpIOp>(finalVal.getDefiningOp ());
94+ arith::CmpIPredicate pred = cmpOp.getPredicate ();
95+ if (pred != arith::CmpIPredicate::slt)
96+ return false ;
97+
98+ Value lhs = tt::intel::getFinalValue (cmpOp.getLhs ());
99+ Value rhs = tt::intel::getFinalValue (cmpOp.getRhs ());
100+ Operation *lhsOp = tt::intel::getFinalValue (lhs).getDefiningOp ();
101+ Operation *rhsOp = tt::intel::getFinalValue (rhs).getDefiningOp ();
102+ if (!lhsOp || !rhsOp)
103+ return false ;
104+
105+ auto getIntConstantValue = [](Operation *op) -> std::optional<APInt> {
106+ DenseElementsAttr constAttr;
107+ if (matchPattern (op, m_Constant (&constAttr))) {
108+ auto attr = constAttr.getSplatValue <Attribute>();
109+ if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
110+ return intAttr.getValue ();
111+ }
112+ return std::nullopt ;
113+ };
114+
115+ // TODO: consider the case where the constant is lhs.
116+ std::optional<APInt> constIntVal = getIntConstantValue (rhsOp);
117+ if (!constIntVal)
118+ return false ;
119+
120+ if (auto addOp = dyn_cast<arith::AddIOp>(lhsOp)) {
121+ Value lhs = tt::intel::getFinalValue (addOp.getLhs ());
122+ Value rhs = tt::intel::getFinalValue (addOp.getRhs ());
123+ if (lhs != forOp.getSingleInductionVar ())
124+ return false ;
125+ if (auto makeRangeOp = dyn_cast<tt::MakeRangeOp>(rhs.getDefiningOp ())) {
126+ APInt maxIV = (*optRange).smax ();
127+ int64_t maxVal = maxIV.getSExtValue () + makeRangeOp.getEnd ();
128+
129+ auto registerMaskValue = [this ](Value mask, bool maskVal) {
130+ for (Operation *user : mask.getUsers ())
131+ opToMaskValue.insert ({user, maskVal});
132+ };
133+
134+ if (maxVal <= constIntVal->getSExtValue ()) {
135+ registerMaskValue (mask, true );
136+ return true ;
137+ }
138+ APInt minIV = (*optRange).smin ();
139+ int64_t minVal = minIV.getSExtValue () + makeRangeOp.getStart ();
140+ if (minVal >= constIntVal->getSExtValue ()) {
141+ registerMaskValue (mask, false );
142+ return true ;
143+ }
144+ }
145+ }
146+
147+ return false ;
148+ }
149+
150+ virtual Value getVersioningCond (scf::ForOp &forOp, Value mask) const {
151+ return {};
152+ }
153+
154+ virtual std::string getName () const { return " RemovableMaskValidator" ; }
155+
156+ bool getMaskValue (Operation *op) const {
157+ assert (opToMaskValue.find (op) != opToMaskValue.end () && " mask not present" );
158+ return opToMaskValue[op];
159+ }
160+
161+ private:
162+ mutable std::map<Operation *, bool > opToMaskValue;
163+
164+ std::optional<ConstantIntRanges> computeLoopIVRange (scf::ForOp forOp) const {
165+ if (!forOp.getSingleInductionVar ())
166+ return std::nullopt ;
167+
168+ if (std::optional<APInt> tripCount = forOp.getStaticTripCount ()) {
169+ OpFoldResult lb = *forOp.getSingleLowerBound ();
170+ OpFoldResult step = *forOp.getSingleStep ();
171+ int64_t lbVal = *getConstantIntValue (lb);
172+ int64_t stepVal = *getConstantIntValue (step);
173+ int64_t lastIVVal = stepVal * (tripCount->getSExtValue () - 1 );
174+ llvm::APInt start (64 , lbVal, true );
175+ llvm::APInt end (64 , lastIVVal, true );
176+ return ConstantIntRanges::range (start, end, true );
177+ }
178+
179+ return std::nullopt ;
180+ }
181+ };
182+
41183// A mask validator which ensures that the mask can be reduced to the form:
42184// `END-1 < N-i*END`
43185class CanonicalMaskValidator final : public MaskValidatorBase {
@@ -51,12 +193,14 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
51193
52194 // Check whether the mask is equivalent to the form: `END-1 < N-i*END`.
53195 virtual bool isValidMask (scf::ForOp &forOp, Value mask) const {
54- assert (mask && " Expecting a valid mask" );
196+ Value finalVal = tt::intel::getFinalValue (mask);
197+ assert (finalVal && " Expecting a valid mask" );
55198
56- if (!mask.getDefiningOp () || !isa<arith::CmpIOp>(mask.getDefiningOp ()))
199+ if (!finalVal.getDefiningOp () ||
200+ !isa<arith::CmpIOp>(finalVal.getDefiningOp ()))
57201 return false ;
58202
59- auto cmpOp = cast<arith::CmpIOp>(mask .getDefiningOp ());
203+ auto cmpOp = cast<arith::CmpIOp>(finalVal .getDefiningOp ());
60204 arith::CmpIPredicate pred = cmpOp.getPredicate ();
61205 if (pred != arith::CmpIPredicate::slt)
62206 return false ;
@@ -103,7 +247,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
103247 // `(N+END-1)/END` (possibly folded), the versioning condition will be:
104248 // `(N+END-1)%END > 0 && N > END`.
105249 virtual Value getVersioningCond (scf::ForOp &forOp, Value mask) const {
106- MaskInfo maskInfo = getMaskInfo (forOp, mask);
250+ Value finalVal = tt::intel::getFinalValue (mask);
251+ assert (finalVal && " Expecting a valid mask" );
252+
253+ MaskInfo maskInfo = getMaskInfo (forOp, finalVal);
107254 if (!hasCanonicalUpperBound (forOp, maskInfo))
108255 return nullptr ;
109256
@@ -203,11 +350,14 @@ class InvariantMaskValidator final : public MaskValidatorBase {
203350 // - [0..END] < splat(N)
204351 // - splat(N) < [0..END]
205352 virtual bool isValidMask (scf::ForOp &forOp, Value mask) const {
206- assert (mask && " Expecting a valid mask" );
207- if (!mask.getDefiningOp () || !isa<arith::CmpIOp>(mask.getDefiningOp ()))
353+ Value finalVal = tt::intel::getFinalValue (mask);
354+ assert (finalVal && " Expecting a valid mask" );
355+
356+ if (!finalVal.getDefiningOp () ||
357+ !isa<arith::CmpIOp>(finalVal.getDefiningOp ()))
208358 return false ;
209359
210- auto cmpOp = cast<arith::CmpIOp>(mask .getDefiningOp ());
360+ auto cmpOp = cast<arith::CmpIOp>(finalVal .getDefiningOp ());
211361 arith::CmpIPredicate pred = cmpOp.getPredicate ();
212362 if (pred != arith::CmpIPredicate::slt)
213363 return false ;
@@ -242,7 +392,7 @@ class InvariantMaskValidator final : public MaskValidatorBase {
242392 }
243393
244394 return false ;
245- }
395+ }
246396
247397 virtual Value getVersioningCond (scf::ForOp &forOp, Value mask) const {
248398 assert (isValidMask (forOp, mask) && " Invalid mask" );
@@ -301,11 +451,8 @@ template <typename MaskValidator> class MaskedOpsCollector {
301451 bool collectMaskedOps () {
302452 auto collectMaskedOps = [&](auto ops, MaskedOperations &maskedOps) {
303453 for (Operation *op : ops) {
304- Value mask = isa<tt::LoadOp>(op) ? cast<tt::LoadOp>(op).getMask ()
305- : isa<tt::StoreOp>(op) ? cast<tt::StoreOp>(op).getMask ()
306- : nullptr ;
307- if (mask &&
308- maskValidator.isValidMask (forOp, tt::intel::getFinalValue (mask))) {
454+ Value mask = getMask (op);
455+ if (mask && maskValidator.isValidMask (forOp, mask)) {
309456 maskedOps.insert (op);
310457 LLVM_DEBUG (llvm::dbgs ()
311458 << maskValidator.getName ()
@@ -316,12 +463,23 @@ template <typename MaskValidator> class MaskedOpsCollector {
316463
317464 collectMaskedOps (forOp.getOps <tt::LoadOp>(), maskedOps);
318465 collectMaskedOps (forOp.getOps <tt::StoreOp>(), maskedOps);
466+ collectMaskedOps (forOp.getOps <arith::SelectOp>(), maskedOps);
319467 return maskedOps.size ();
320468 }
321469
322470 const MaskedOperations &getMaskedOps () const { return maskedOps; };
323471 const MaskValidator &getMaskValidator () const { return maskValidator; }
324472
473+ Value getMask (Operation *op) const {
474+ assert (op && " Expecting a valid operation" );
475+ return TypeSwitch<Operation *, Value>(op)
476+ .Case <tt::LoadOp, tt::StoreOp>(
477+ [](auto maskedOp) { return maskedOp.getMask (); })
478+ .template Case <arith::SelectOp>(
479+ [](auto selectOp) { return selectOp.getCondition (); })
480+ .Default ([](auto ) { return nullptr ; });
481+ }
482+
325483private:
326484 scf::ForOp &forOp;
327485 MaskValidator &maskValidator;
@@ -515,6 +673,28 @@ struct TritonIntelRemoveMasksBase
515673 void runOnOperation () final {
516674 ModuleOp moduleOp = getOperation ();
517675
676+ // Remove masks if they are not necessary.
677+ moduleOp->walk <WalkOrder::PreOrder>([&](Operation *op) {
678+ if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
679+ // Nested loops aren't currently handled.
680+ if (forOp->template getParentOfType <scf::ForOp>())
681+ return WalkResult::advance ();
682+
683+ if (!forOp.getSingleInductionVar ())
684+ return WalkResult::advance ();
685+
686+ RemovableMaskValidator maskValidator;
687+ MaskedOpsCollector collector (forOp, maskValidator);
688+ if (collector.collectMaskedOps ()) {
689+ for (Operation *op : collector.getMaskedOps ()) {
690+ bool maskVal = maskValidator.getMaskValue (op);
691+ dropMask (op, maskVal);
692+ }
693+ }
694+ }
695+ return WalkResult::advance ();
696+ });
697+
518698 // Version loops containing masked operation in canonical form.
519699 moduleOp->walk <WalkOrder::PreOrder>([&](Operation *op) {
520700 if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) {
0 commit comments