@@ -34,10 +34,12 @@ class MaskValidatorBase {
3434
3535 // Create the loop versioning condition based on the mask.
3636 virtual Value getVersioningCond (scf::ForOp &forOp, Value mask) const = 0;
37+
38+ virtual std::string getName () const = 0;
3739};
3840
3941// A mask validator which ensures that the mask can be reduced to the form:
40- // `END < N-i*END`.
42+ // `END-1 < N-i*END`
4143class CanonicalMaskValidator final : public MaskValidatorBase {
4244public:
4345 // This structure is used to store the information about a mask in canonical
@@ -47,7 +49,7 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
4749 unsigned END;
4850 };
4951
50- // Check whether the mask is equivalent to the form: `END < N-i*END`.
52+ // Check whether the mask is equivalent to the form: `END-1 < N-i*END`.
5153 virtual bool isValidMask (scf::ForOp &forOp, Value mask) const {
5254 assert (mask && " Expecting a valid mask" );
5355
@@ -102,8 +104,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
102104 // `(N+END-1)%END > 0 && N > END`.
103105 virtual Value getVersioningCond (scf::ForOp &forOp, Value mask) const {
104106 MaskInfo maskInfo = getMaskInfo (forOp, mask);
105- assert ( hasCanonicalUpperBound (forOp, maskInfo) &&
106- " Loop upper bound not in canonical form " ) ;
107+ if (! hasCanonicalUpperBound (forOp, maskInfo))
108+ return nullptr ;
107109
108110 OpBuilder builder (forOp);
109111 Location loc = forOp.getLoc ();
@@ -113,12 +115,13 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
113115
114116 // The loop UB is a constant.
115117 if (isa<arith::ConstantIntOp>(defOp)) {
116- int64_t valN =
118+ int64_t UB = cast<arith::ConstantIntOp>(defOp).value ();
119+ int64_t N =
117120 cast<arith::ConstantIntOp>(maskInfo.N .getDefiningOp ()).value ();
118- bool cond1 = ((valN + maskInfo.END - 1 ) % maskInfo. END ) > 0 ;
119- bool cond2 = valN > maskInfo. END ;
120- return builder.create <arith::ConstantIntOp>(
121- forOp. getLoc (), cond1 && cond2, builder.getI1Type ());
121+ unsigned END = maskInfo.END ;
122+ bool cond = UB <= ((N + END - 1 ) / END) - 1 ;
123+ return builder.create <arith::ConstantIntOp>(forOp. getLoc (), cond,
124+ builder.getI1Type ());
122125 }
123126
124127 auto divOp = cast<arith::DivSIOp>(defOp);
@@ -137,6 +140,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
137140 return builder.create <arith::AndIOp>(loc, cmp1, cmp2);
138141 }
139142
143+ virtual std::string getName () const { return " CanonicalMaskValidator" ; }
144+
140145 // Ensure the loop upper bound is in canonical form (N+END-1)/END.
141146 static bool hasCanonicalUpperBound (scf::ForOp &forOp,
142147 const MaskInfo &maskInfo) {
@@ -148,10 +153,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
148153 // If the loop UB is constant, use `MaskInfo` to determine whether the UB
149154 // was folded from a canonical form.
150155 if (isa<arith::ConstantIntOp>(defOp)) {
151- int64_t valN =
156+ int64_t UB = cast<arith::ConstantIntOp>(defOp).value ();
157+ int64_t N =
152158 cast<arith::ConstantIntOp>(maskInfo.N .getDefiningOp ()).value ();
153- return ((valN + maskInfo.END - 1 ) / maskInfo.END ) ==
154- cast<arith::ConstantIntOp>(defOp).value ();
159+ return UB == ((N + maskInfo.END - 1 ) / maskInfo.END ) - 1 ;
155160 }
156161
157162 if (!isa<arith::DivSIOp>(defOp))
@@ -279,6 +284,8 @@ class InvariantMaskValidator final : public MaskValidatorBase {
279284 llvm_unreachable (" Unexpected mask" );
280285 return {};
281286 }
287+
288+ virtual std::string getName () const { return " InvariantMaskValidator" ; }
282289};
283290
284291// Collects masked operations in a loop that satisfy the condition imposed by
@@ -300,7 +307,8 @@ template <typename MaskValidator> class MaskedOpsCollector {
300307 maskValidator.isValidMask (forOp, tt::intel::getFinalValue (mask))) {
301308 maskedOps.insert (op);
302309 LLVM_DEBUG (llvm::dbgs ()
303- << " Collected masked operation: " << *op << " \n " );
310+ << maskValidator.getName ()
311+ << " : collected masked operation: " << *op << " \n " );
304312 }
305313 }
306314 };
0 commit comments