Skip to content

Commit e07bc13

Browse files
authored
Fix "off by one" error in RemoveMask pass (#4194)
Fixes #4186, #4187. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent c788bfd commit e07bc13

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

test/Triton/Intel/RemoveMasks/loop-canonical-masks.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ module {
110110
// CHECK: }
111111

112112
tt.func public @test_kernel2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
113+
%c7_i32 = arith.constant 7 : i32
113114
%c8_i32 = arith.constant 8 : i32
114115
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32>
115116
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16>
@@ -164,7 +165,7 @@ module {
164165
%33 = arith.addi %31, %32 : tensor<64x256xi32>
165166
%34 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>>
166167
%35 = tt.addptr %34, %33 : tensor<64x256x!tt.ptr<f16>>, tensor<64x256xi32>
167-
%36:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %27, %arg6 = %35) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x256x!tt.ptr<f16>>) : i32 {
168+
%36:3 = scf.for %arg3 = %c0_i32 to %c7_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %27, %arg6 = %35) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x256x!tt.ptr<f16>>) : i32 {
168169
%51 = arith.muli %arg3, %c64_i32 : i32
169170
%52 = arith.subi %c512_i32, %51 : i32
170171
%53 = tt.splat %52 : i32 -> tensor<1x64xi32>

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ class MaskValidatorBase {
3434

3535
// Create the loop versioning condition based on the mask.
3636
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const = 0;
37+
38+
virtual std::string getName() const = 0;
3739
};
3840

3941
// A mask validator which ensures that the mask can be reduced to the form:
40-
// `END < N-i*END`.
42+
// `END-1 < N-i*END`
4143
class CanonicalMaskValidator final : public MaskValidatorBase {
4244
public:
4345
// This structure is used to store the information about a mask in canonical
@@ -47,7 +49,7 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
4749
unsigned END;
4850
};
4951

50-
// Check whether the mask is equivalent to the form: `END < N-i*END`.
52+
// Check whether the mask is equivalent to the form: `END-1 < N-i*END`.
5153
virtual bool isValidMask(scf::ForOp &forOp, Value mask) const {
5254
assert(mask && "Expecting a valid mask");
5355

@@ -102,8 +104,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
102104
// `(N+END-1)%END > 0 && N > END`.
103105
virtual Value getVersioningCond(scf::ForOp &forOp, Value mask) const {
104106
MaskInfo maskInfo = getMaskInfo(forOp, mask);
105-
assert(hasCanonicalUpperBound(forOp, maskInfo) &&
106-
"Loop upper bound not in canonical form");
107+
if (!hasCanonicalUpperBound(forOp, maskInfo))
108+
return nullptr;
107109

108110
OpBuilder builder(forOp);
109111
Location loc = forOp.getLoc();
@@ -113,12 +115,13 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
113115

114116
// The loop UB is a constant.
115117
if (isa<arith::ConstantIntOp>(defOp)) {
116-
int64_t valN =
118+
int64_t UB = cast<arith::ConstantIntOp>(defOp).value();
119+
int64_t N =
117120
cast<arith::ConstantIntOp>(maskInfo.N.getDefiningOp()).value();
118-
bool cond1 = ((valN + maskInfo.END - 1) % maskInfo.END) > 0;
119-
bool cond2 = valN > maskInfo.END;
120-
return builder.create<arith::ConstantIntOp>(
121-
forOp.getLoc(), cond1 && cond2, builder.getI1Type());
121+
unsigned END = maskInfo.END;
122+
bool cond = UB <= ((N + END - 1) / END) - 1;
123+
return builder.create<arith::ConstantIntOp>(forOp.getLoc(), cond,
124+
builder.getI1Type());
122125
}
123126

124127
auto divOp = cast<arith::DivSIOp>(defOp);
@@ -137,6 +140,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
137140
return builder.create<arith::AndIOp>(loc, cmp1, cmp2);
138141
}
139142

143+
virtual std::string getName() const { return "CanonicalMaskValidator"; }
144+
140145
// Ensure the loop upper bound is in canonical form (N+END-1)/END.
141146
static bool hasCanonicalUpperBound(scf::ForOp &forOp,
142147
const MaskInfo &maskInfo) {
@@ -148,10 +153,10 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
148153
// If the loop UB is constant, use `MaskInfo` to determine whether the UB
149154
// was folded from a canonical form.
150155
if (isa<arith::ConstantIntOp>(defOp)) {
151-
int64_t valN =
156+
int64_t UB = cast<arith::ConstantIntOp>(defOp).value();
157+
int64_t N =
152158
cast<arith::ConstantIntOp>(maskInfo.N.getDefiningOp()).value();
153-
return ((valN + maskInfo.END - 1) / maskInfo.END) ==
154-
cast<arith::ConstantIntOp>(defOp).value();
159+
return UB == ((N + maskInfo.END - 1) / maskInfo.END) - 1;
155160
}
156161

157162
if (!isa<arith::DivSIOp>(defOp))
@@ -279,6 +284,8 @@ class InvariantMaskValidator final : public MaskValidatorBase {
279284
llvm_unreachable("Unexpected mask");
280285
return {};
281286
}
287+
288+
virtual std::string getName() const { return "InvariantMaskValidator"; }
282289
};
283290

284291
// Collects masked operations in a loop that satisfy the condition imposed by
@@ -300,7 +307,8 @@ template <typename MaskValidator> class MaskedOpsCollector {
300307
maskValidator.isValidMask(forOp, tt::intel::getFinalValue(mask))) {
301308
maskedOps.insert(op);
302309
LLVM_DEBUG(llvm::dbgs()
303-
<< "Collected masked operation: " << *op << "\n");
310+
<< maskValidator.getName()
311+
<< ": collected masked operation: " << *op << "\n");
304312
}
305313
}
306314
};

0 commit comments

Comments
 (0)