Commit 77b8626

Increase test coverage for the RemoveMask pass (#3733)

Add unit tests for the `RemoveMask` transformation.

Signed-off-by: Tiotto, Ettore <[email protected]>

1 parent f762612

File tree: 6 files changed (+281, -50 lines)
Lines changed: 111 additions & 0 deletions (new test file)

// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s

module {
  tt.func public @test_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
    %c31_i32 = arith.constant 31 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32>
    %c127_i32 = arith.constant 127 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf16>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_2 = arith.constant dense<32> : tensor<64x32xi32>
    %c32_i32 = arith.constant 32 : i32
    %c128_i32 = arith.constant 128 : i32
    %c64_i32 = arith.constant 64 : i32
    %c4_i32 = arith.constant 4 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c63_i32 : i32
    %2 = arith.divsi %1, %c64_i32 : i32
    %3 = arith.addi %arg4, %c127_i32 : i32
    %4 = arith.divsi %3, %c128_i32 : i32
    %5 = arith.muli %4, %c4_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c4_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c4_i32 : i32
    %10 = arith.remsi %0, %5 : i32
    %11 = arith.remsi %10, %9 : i32
    %12 = arith.addi %7, %11 : i32
    %13 = arith.divsi %10, %9 : i32
    %14 = arith.muli %12, %c64_i32 : i32
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %16 = tt.splat %14 : i32 -> tensor<64xi32>
    %17 = arith.addi %16, %15 : tensor<64xi32>
    %18 = tt.splat %arg3 : i32 -> tensor<64xi32>
    %19 = arith.remsi %17, %18 : tensor<64xi32>
    %20 = arith.muli %13, %c128_i32 : i32
    %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %22 = tt.splat %20 : i32 -> tensor<128xi32>
    %23 = arith.addi %22, %21 : tensor<128xi32>
    %24 = tt.splat %arg4 : i32 -> tensor<128xi32>
    %25 = arith.remsi %23, %24 : tensor<128xi32>
    %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %27 = tt.expand_dims %19 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %28 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
    %29 = arith.muli %27, %28 : tensor<64x1xi32>
    %30 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %31 = tt.broadcast %29 : tensor<64x1xi32> -> tensor<64x32xi32>
    %32 = tt.broadcast %30 : tensor<1x32xi32> -> tensor<64x32xi32>
    %33 = arith.addi %31, %32 : tensor<64x32xi32>
    %34 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x32x!tt.ptr<f16>>
    %35 = tt.addptr %34, %33 : tensor<64x32x!tt.ptr<f16>>, tensor<64x32xi32>
    %36 = tt.expand_dims %26 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
    %37 = tt.splat %arg7 : i32 -> tensor<32x1xi32>
    %38 = arith.muli %36, %37 : tensor<32x1xi32>
    %39 = tt.expand_dims %25 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
    %40 = tt.broadcast %38 : tensor<32x1xi32> -> tensor<32x128xi32>
    %41 = tt.broadcast %39 : tensor<1x128xi32> -> tensor<32x128xi32>
    %42 = arith.addi %40, %41 : tensor<32x128xi32>
    %43 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>>
    %44 = tt.addptr %43, %42 : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>
    %45 = arith.addi %arg5, %c31_i32 : i32
    %46 = arith.divsi %45, %c32_i32 : i32
    %47 = arith.muli %arg7, %c32_i32 : i32
    %48 = tt.splat %47 : i32 -> tensor<32x128xi32>
    %49:3 = scf.for %arg9 = %c0_i32 to %46 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %35, %arg12 = %44) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
      %67 = arith.muli %arg9, %c32_i32 : i32
      %68 = arith.subi %arg5, %67 : i32
      %69 = tt.splat %68 : i32 -> tensor<1x32xi32>
      %70 = arith.cmpi slt, %30, %69 : tensor<1x32xi32>
      %71 = tt.broadcast %70 : tensor<1x32xi1> -> tensor<64x32xi1>
      %72 = tt.load %arg11, %71, %cst_1 : tensor<64x32x!tt.ptr<f16>>
      %73 = tt.splat %68 : i32 -> tensor<32x1xi32>
      %74 = arith.cmpi slt, %36, %73 : tensor<32x1xi32>
      %75 = tt.broadcast %74 : tensor<32x1xi1> -> tensor<32x128xi1>
      %76 = tt.load %arg12, %75, %cst_0 : tensor<32x128x!tt.ptr<f16>>
      %77 = tt.dot %72, %76, %arg10, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32>
      %78 = tt.addptr %arg11, %cst_2 : tensor<64x32x!tt.ptr<f16>>, tensor<64x32xi32>
      %79 = tt.addptr %arg12, %48 : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>
      scf.yield %77, %78, %79 : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
    }
    %50 = arith.truncf %49#0 : tensor<64x128xf32> to tensor<64x128xf16>
    tt.return
  }

  // CHECK: tt.func public @test_kernel([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_3_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_4_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_5_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_6_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_7_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_8_:%.+]]: i32 {tt.divisibility = 16 : i32}) {
  // CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32
  // CHECK: [[CST_32_i32:%.+]] = arith.constant 32 : i32
  // CHECK: [[REM:%.+]] = arith.remsi [[PARAM_5_]], [[CST_32_i32]] : i32
  // CHECK: [[CMP1:%.+]] = arith.cmpi eq, [[REM]], [[CST_0_i32]] : i32
  // CHECK: [[CMP2:%.+]] = arith.cmpi sgt, [[PARAM_5_]], [[CST_32_i32]] : i32
  // CHECK: [[VER_COND:%.+]] = arith.andi [[CMP1]], [[CMP2]] : i1
  // CHECK: [[LOOP_VER:%.+]] = scf.if [[VER_COND]] -> (tensor<64x128xf32>) {
  // CHECK:   [[THEN_LOOP_RES:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg10:%.+]] = {{.*}}, [[VAR_arg11:%.+]] = {{.*}}, [[VAR_arg12:%.+]] = {{.*}}) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
  // CHECK:     [[LOAD_A1:%.+]] = tt.load [[VAR_arg11]] : tensor<64x32x!tt.ptr<f16>>
  // CHECK:     [[LOAD_B2:%.+]] = tt.load [[VAR_arg12]] : tensor<32x128x!tt.ptr<f16>>
  // CHECK:     scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
  // CHECK:   }
  // CHECK:   scf.yield [[THEN_LOOP_RES]]#0 : tensor<64x128xf32>
  // CHECK: } else {
  // CHECK:   [[ELSE_LOOP_RES:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg10:%.+]] = {{.*}}, [[VAR_arg11:%.+]] = {{.*}}, [[VAR_arg12:%.+]] = {{.*}}) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
  // CHECK:     [[LOAD_A2:%.+]] = tt.load [[VAR_arg11]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>>
  // CHECK:     [[LOAD_B2:%.+]] = tt.load [[VAR_arg12]], {{.*}}, {{.*}} : tensor<32x128x!tt.ptr<f16>>
  // CHECK:     scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
  // CHECK:   }
  // CHECK:   scf.yield [[ELSE_LOOP_RES]]#0 : tensor<64x128xf32>
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }
}
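The CHECK lines above pin down the versioning condition the pass builds for this GEMM-style kernel: the masked loads in the K-loop can be replaced by unmasked ones only when the K extent (%arg5) splits evenly into 32-wide tiles and exceeds one tile. A minimal host-side model of that predicate (illustrative C++ only, not part of the commit; takesUnmaskedLoop is a made-up name):

#include <cstdio>

// Scalar model of the versioning condition in the CHECK lines:
// rem = K % 32; cond = (rem == 0) && (K > 32).
static bool takesUnmaskedLoop(int K) {
  return K % 32 == 0 && K > 32;
}

int main() {
  for (int K : {16, 32, 64, 96, 100})
    std::printf("K = %3d -> %s\n", K,
                takesUnmaskedLoop(K) ? "then branch (unmasked loads)"
                                     : "else branch (masked loads)");
  return 0;
}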
Lines changed: 109 additions & 0 deletions (new test file)

// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s

module {
  // COM: Test loop versioning for loads with invariant mask operations.
  // COM: Masks of the form [0..END] < splat(X).
  tt.func public @test_invariant_masks_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32, %arg3: i32) {
    %cst = arith.constant dense<0xFF800000> : tensor<1024xf32>
    %cst_0 = arith.constant dense<0xFF800000> : tensor<512xf32>
    %0 = tt.get_program_id x : i32
    %1 = tt.get_num_programs x : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %arg3 : i32 -> tensor<1024xi32>
    %4 = arith.cmpi slt, %2, %3 : tensor<1024xi32>
    %5 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %6 = tt.splat %arg3 : i32 -> tensor<512xi32>
    %7 = arith.cmpi slt, %5, %6 : tensor<512xi32>
    scf.for %arg6 = %0 to %arg2 step %1 : i32 {
      %8 = arith.muli %arg6, %arg1 : i32
      %9 = tt.addptr %arg0, %8 : !tt.ptr<f32>, i32
      %10 = tt.splat %9 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
      %11 = tt.splat %9 : !tt.ptr<f32> -> tensor<512x!tt.ptr<f32>>
      %12 = tt.addptr %10, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      %13 = tt.addptr %11, %5 : tensor<512x!tt.ptr<f32>>, tensor<512xi32>
      %14 = tt.load %12, %4, %cst : tensor<1024x!tt.ptr<f32>>
      %15 = tt.load %13, %7, %cst_0 : tensor<512x!tt.ptr<f32>>
    }
    tt.return
  }

  // CHECK: tt.func public @test_invariant_masks_1([[PARAM_0:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32, [[PARAM_3:%.+]]: i32) {
  // CHECK: [[VAR_2:%.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  // CHECK: [[VAR_3:%.+]] = tt.splat [[PARAM_3]] : i32 -> tensor<1024xi32>
  // CHECK: [[VAR_4:%.+]] = arith.cmpi slt, [[VAR_2]], [[VAR_3]] : tensor<1024xi32>
  // CHECK: [[VAR_5:%.+]] = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
  // CHECK: [[VAR_6:%.+]] = tt.splat [[PARAM_3]] : i32 -> tensor<512xi32>
  // CHECK-DAG: [[VAR_7:%.+]] = arith.cmpi slt, [[VAR_5]], [[VAR_6]] : tensor<512xi32>
  // CHECK-DAG: [[CST_1023:%.+]] = arith.constant 1023 : i32
  // CHECK: [[VAR_8:%.+]] = arith.cmpi sgt, [[PARAM_3]], [[CST_1023]] : i32
  // CHECK: [[CST_511:%.+]] = arith.constant 511 : i32
  // CHECK: [[VAR_9:%.+]] = arith.cmpi sgt, [[PARAM_3]], [[CST_511]] : i32
  // CHECK: [[VAR_10:%.+]] = arith.andi [[VAR_8]], [[VAR_9]] : i1
  // CHECK: scf.if [[VAR_10]] {
  // CHECK:   scf.for {{.+}} = {{.+}} to {{.+}} step {{.+}} : i32 {
  // CHECK-NOT:  tt.load {{.*}}, {{.*}}, {{.*}}
  // CHECK:   }
  // CHECK: } else {
  // CHECK:   scf.for {{.+}} = {{.+}} to {{.+}} step {{.+}} : i32 {
  // CHECK-DAG: [[LOAD_A2:%.+]] = tt.load {{.*}}, [[VAR_4]], {{.*}} : tensor<1024x!tt.ptr<f32>>
  // CHECK-DAG: [[LOAD_B2:%.+]] = tt.load {{.*}}, [[VAR_7]], {{.*}} : tensor<512x!tt.ptr<f32>>
  // CHECK:   }
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }
}

// -----

module {
  // COM: Test loop versioning for loads with invariant mask operations.
  // COM: Masks of the form splat(X) < [0..END].
  tt.func public @test_invariant_masks_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32, %arg3: i32) {
    %cst = arith.constant dense<0xFF800000> : tensor<1024xf32>
    %cst_0 = arith.constant dense<0xFF800000> : tensor<512xf32>
    %0 = tt.get_program_id x : i32
    %1 = tt.get_num_programs x : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %arg3 : i32 -> tensor<1024xi32>
    %4 = arith.cmpi slt, %3, %2 : tensor<1024xi32>
    %5 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %6 = tt.splat %arg3 : i32 -> tensor<512xi32>
    %7 = arith.cmpi slt, %6, %5 : tensor<512xi32>
    scf.for %arg6 = %0 to %arg2 step %1 : i32 {
      %8 = arith.muli %arg6, %arg1 : i32
      %9 = tt.addptr %arg0, %8 : !tt.ptr<f32>, i32
      %10 = tt.splat %9 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
      %11 = tt.splat %9 : !tt.ptr<f32> -> tensor<512x!tt.ptr<f32>>
      %12 = tt.addptr %10, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      %13 = tt.addptr %11, %5 : tensor<512x!tt.ptr<f32>>, tensor<512xi32>
      %14 = tt.load %12, %4, %cst : tensor<1024x!tt.ptr<f32>>
      %15 = tt.load %13, %7, %cst_0 : tensor<512x!tt.ptr<f32>>
    }
    tt.return
  }

  // CHECK: tt.func public @test_invariant_masks_2([[PARAM_0:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32, [[PARAM_3:%.+]]: i32) {
  // CHECK: [[VAR_2:%.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  // CHECK: [[VAR_3:%.+]] = tt.splat [[PARAM_3]] : i32 -> tensor<1024xi32>
  // CHECK: [[VAR_4:%.+]] = arith.cmpi slt, [[VAR_3]], [[VAR_2]] : tensor<1024xi32>
  // CHECK: [[VAR_5:%.+]] = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
  // CHECK: [[VAR_6:%.+]] = tt.splat [[PARAM_3]] : i32 -> tensor<512xi32>
  // CHECK-DAG: [[VAR_7:%.+]] = arith.cmpi slt, [[VAR_6]], [[VAR_5]] : tensor<512xi32>
  // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
  // CHECK-DAG: [[VAR_8:%.+]] = arith.cmpi slt, [[PARAM_3]], [[CST_0_i32]] : i32
  // CHECK-DAG: [[CST_0_1_i32:%.+]] = arith.constant 0 : i32
  // CHECK-DAG: [[VAR_9:%.+]] = arith.cmpi slt, [[PARAM_3]], [[CST_0_1_i32]] : i32
  // CHECK: [[VAR_10:%.+]] = arith.andi [[VAR_8]], [[VAR_9]] : i1
  // CHECK: scf.if [[VAR_10]] {
  // CHECK:   scf.for {{.+}} = {{.+}} to {{.+}} step {{.+}} : i32 {
  // CHECK-NOT:  tt.load {{.*}}, {{.*}}, {{.*}}
  // CHECK:   }
  // CHECK: } else {
  // CHECK:   scf.for {{.+}} = {{.+}} to {{.+}} step {{.+}} : i32 {
  // CHECK-DAG: [[LOAD_A2:%.+]] = tt.load {{.*}}, [[VAR_4]], {{.*}} : tensor<1024x!tt.ptr<f32>>
  // CHECK-DAG: [[LOAD_B2:%.+]] = tt.load {{.*}}, [[VAR_7]], {{.*}} : tensor<512x!tt.ptr<f32>>
  // CHECK:   }
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }
}
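Both tests exercise masks that are invariant in the scf.for loop, so the pass hoists a scalar condition that is true exactly when every mask lane is true: for [0..END) < splat(N) that is END-1 < N (the 1023 and 511 constants above), and for splat(N) < [0..END) the smallest lane index is the range start, so the condition collapses to N < 0 here (both ranges start at 0). A small self-check of that lane arithmetic (illustrative C++ only, not part of the commit):

#include <cassert>

// All lanes of [0..end) < splat(n) hold iff the largest index (end-1) passes.
static bool rangeLtSplat(int end, int n) { return end - 1 < n; }

// All lanes of splat(n) < [0..end) hold iff the smallest index (0) passes.
static bool splatLtRange(int n) { return n < 0; }

int main() {
  assert(rangeLtSplat(1024, 1024));  // lane 1023 < 1024: every lane is true
  assert(!rangeLtSplat(1024, 512));  // lane 1023 fails: 1023 < 512 is false
  assert(splatLtRange(-1));          // -1 < 0, -1 < 1, ...: every lane is true
  assert(!splatLtRange(0));          // lane 0 fails: 0 < 0 is false
  return 0;
}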

third_party/intel/include/Utils/Utility.h

Lines changed: 9 additions & 2 deletions

@@ -1,10 +1,17 @@
 #ifndef TRITON_INTEL_UTILS_UTILITY_H
 #define TRITON_INTEL_UTILS_UTILITY_H

-#include <mlir/IR/Value.h>
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"

 namespace mlir::triton::intel {

+// Looks up an integer constant with the given value and bit width in the
+// current block (before the builder's insertion point). Returns it if found;
+// otherwise creates a new one.
+Value findOrCreateIntConstant(Location loc, int val, unsigned bitWidth,
+                              OpBuilder &builder);
+
 // This function folds the `op` operation and returns the constant value if it
 // has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
 std::optional<int64_t> getFoldedConstantValue(Operation *op);
@@ -13,7 +20,7 @@ std::optional<int64_t> getFoldedConstantValue(Operation *op);
 // expected.
 bool isConstant(Value val, int64_t expected);

-mlir::Value getFinalValue(Value value);
+Value getFinalValue(Value value);

 } // namespace mlir::triton::intel
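The header only declares the helper. As a rough sketch of what it could look like (not the committed implementation; assumes the usual MLIR and arith includes and mirrors the arith::ConstantIntOp builder already used in this file's callers), a linear scan over the ops preceding the insertion point would suffice:

Value findOrCreateIntConstant(Location loc, int val, unsigned bitWidth,
                              OpBuilder &builder) {
  Type intTy = builder.getIntegerType(bitWidth);
  Block *block = builder.getInsertionBlock();
  // Reuse an existing arith.constant with a matching value and type...
  for (Operation &op :
       llvm::make_range(block->begin(), builder.getInsertionPoint())) {
    auto cstOp = dyn_cast<arith::ConstantIntOp>(&op);
    if (cstOp && cstOp.getType() == intTy && cstOp.value() == val)
      return cstOp;
  }
  // ...otherwise materialize a fresh one at the insertion point.
  return builder.create<arith::ConstantIntOp>(loc, val, intTy);
}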

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

Lines changed: 7 additions & 7 deletions

@@ -102,8 +102,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {

     OpBuilder builder(forOp);
     Location loc = forOp.getLoc();
-    Value zero =
-        builder.createOrFold<arith::ConstantIntOp>(loc, 0, lhs.getType());
+    Value zero = tt::intel::findOrCreateIntConstant(
+        loc, 0, lhs.getType().getIntOrFloatBitWidth(), builder);
     Value cmp1 = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq,
         builder.create<arith::RemSIOp>(loc, lhs, rhs), zero);
@@ -204,18 +204,18 @@ class InvariantMaskValidator final : public MaskValidatorBase {
       return builder.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  lhsVal, rhsVal);

-    // [0..END] < splat(N)
+    // [0..END] < splat(N) -- generate versioning condition 'END-1 < N'.
     if (!rhs && isa<tt::MakeRangeOp>(lhs)) {
       [[maybe_unused]] auto rangeOp = cast<tt::MakeRangeOp>(lhs);
       assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
-      unsigned end = rangeOp.getEnd();
-      auto cstOp = builder.createOrFold<arith::ConstantIntOp>(loc, end,
-                                                              rhsVal.getType());
+      unsigned end = rangeOp.getEnd() - 1u;
+      auto cstOp = tt::intel::findOrCreateIntConstant(
+          loc, end, rhsVal.getType().getIntOrFloatBitWidth(), builder);
       return builder.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                  cstOp, rhsVal);
     }

-    // splat(N) < [0..END]
+    // splat(N) < [0..END] -- generate versioning condition 'N < START'.
     if (!lhs && isa<tt::MakeRangeOp>(rhs)) {
       [[maybe_unused]] auto rangeOp = cast<tt::MakeRangeOp>(rhs);
       assert(rangeOp.getStart() < rangeOp.getEnd() && "Invalid range");
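Two behavioral changes ride along in this hunk. First, constants are now obtained through tt::intel::findOrCreateIntConstant, so the versioning code reuses an arith.constant that already exists before the loop instead of materializing a duplicate. Second, the condition for masks of the form [0..END) < splat(N) is corrected from 'END < N' to 'END-1 < N': the largest lane index tt.make_range produces is END-1, so the mask is all-true exactly when END-1 < N, and the old condition needlessly fell back to the masked loop in the boundary case N == END. For the symmetric splat(N) < [0..END) case, the condition compares N against the range start (the 0 constants in the tests above), since all lanes hold exactly when N is below the smallest index.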
