Skip to content

Commit 7034506

Browse files
committed
Fix pre-commit
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent af1fd37 commit 7034506

File tree

1 file changed

+47
-47
lines changed

1 file changed

+47
-47
lines changed

test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,59 @@
22

33
module {
44
tt.func public @test1(%in_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
5-
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16>
6-
%c576_i32 = arith.constant 576 : i32
7-
%c0_i32 = arith.constant 0 : i32
8-
%cst_4 = arith.constant dense<9216> : tensor<32x1xi32>
9-
%cst_5 = arith.constant dense<16> : tensor<1x32xi32>
10-
%cst_6 = arith.constant dense<576> : tensor<1x32xi32>
11-
%cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32>
12-
%cst_8 = arith.constant dense<16> : tensor<32x1xi32>
13-
%c32_i32 = arith.constant 32 : i32
14-
%0 = tt.get_program_id x : i32
15-
%1 = arith.muli %0, %c32_i32 : i32
16-
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
17-
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
18-
%4 = tt.splat %1 : i32 -> tensor<32x1xi32>
19-
%5 = arith.addi %4, %3 : tensor<32x1xi32>
20-
%6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
21-
%7 = arith.remsi %5, %cst_8 : tensor<32x1xi32>
22-
%8 = arith.divsi %5, %cst_8 : tensor<32x1xi32>
23-
%9 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>>
24-
%10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32>
25-
%11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr<f16>>
26-
%12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32>
27-
%13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32>
28-
%14 = arith.muli %8, %cst_4 : tensor<32x1xi32>
29-
%15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32>
30-
%16 = tt.splat %in_ptr0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>>
31-
%17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32>
5+
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16>
6+
%c576_i32 = arith.constant 576 : i32
7+
%c0_i32 = arith.constant 0 : i32
8+
%cst_4 = arith.constant dense<9216> : tensor<32x1xi32>
9+
%cst_5 = arith.constant dense<16> : tensor<1x32xi32>
10+
%cst_6 = arith.constant dense<576> : tensor<1x32xi32>
11+
%cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32>
12+
%cst_8 = arith.constant dense<16> : tensor<32x1xi32>
13+
%c32_i32 = arith.constant 32 : i32
14+
%0 = tt.get_program_id x : i32
15+
%1 = arith.muli %0, %c32_i32 : i32
16+
%2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
17+
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
18+
%4 = tt.splat %1 : i32 -> tensor<32x1xi32>
19+
%5 = arith.addi %4, %3 : tensor<32x1xi32>
20+
%6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
21+
%7 = arith.remsi %5, %cst_8 : tensor<32x1xi32>
22+
%8 = arith.divsi %5, %cst_8 : tensor<32x1xi32>
23+
%9 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>>
24+
%10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32>
25+
%11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr<f16>>
26+
%12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32>
27+
%13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32>
28+
%14 = arith.muli %8, %cst_4 : tensor<32x1xi32>
29+
%15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32>
30+
%16 = tt.splat %in_ptr0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>>
31+
%17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32>
3232
%_tmp5 = scf.for %r0_offset = %c0_i32 to %c576_i32 step %c32_i32 iter_args(%_tmp5_9 = %cst_7) -> (tensor<32x32xf32>) : i32 {
33-
%44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32>
34-
%45 = arith.addi %44, %6 : tensor<1x32xi32>
35-
%46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32>
36-
%47 = arith.muli %45, %cst_5 : tensor<1x32xi32>
37-
%48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32>
38-
%49 = arith.addi %13, %48 : tensor<32x32xi32>
39-
%50 = arith.addi %49, %15 : tensor<32x32xi32>
40-
%51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
41-
%52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1>
42-
%53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
43-
%54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32>
44-
%55 = arith.addf %54, %17 : tensor<32x32xf32>
45-
%mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32>
46-
%56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32>
47-
%mask_10 = arith.ori %mask, %56 : tensor<32x32xi1>
48-
%57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32>
49-
%58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32>
50-
scf.yield %58 : tensor<32x32xf32>
51-
}
33+
%44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32>
34+
%45 = arith.addi %44, %6 : tensor<1x32xi32>
35+
%46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32>
36+
%47 = arith.muli %45, %cst_5 : tensor<1x32xi32>
37+
%48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32>
38+
%49 = arith.addi %13, %48 : tensor<32x32xi32>
39+
%50 = arith.addi %49, %15 : tensor<32x32xi32>
40+
%51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
41+
%52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1>
42+
%53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
43+
%54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32>
44+
%55 = arith.addf %54, %17 : tensor<32x32xf32>
45+
%mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32>
46+
%56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32>
47+
%mask_10 = arith.ori %mask, %56 : tensor<32x32xi1>
48+
%57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32>
49+
%58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32>
50+
scf.yield %58 : tensor<32x32xf32>
51+
}
5252
tt.return
5353
}
5454
// CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
5555
// CHECK: scf.for
5656
// CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
57-
// CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
57+
// CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
5858
// CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32>
5959
// CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1>
6060
// CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32>

0 commit comments

Comments
 (0)