Skip to content

Commit 0e374bf

Browse files
authored
Remove unnecessary masks (#5384)
This PR extends the `RemoveMask` pass to consider masks on load and select operations that evaluate to true (or false) over the entire loop iteration space. This allows masked loads to be transformed into unmasked ones, and the mask condition may become dead if not used by other operations (thereby reducing arithmetic complexity). --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent e427f29 commit 0e374bf

File tree

2 files changed

+327
-15
lines changed

2 files changed

+327
-15
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s

module {
  // test1: the load mask (%52) is the broadcast of (r0_offset + [0, 32)) < 576.
  // The enclosing scf.for runs r0_offset from 0 to 576 with step 32, so the
  // comparison holds on every loop iteration.  The pass is expected to drop the
  // mask from the tt.load and fold away the final arith.select on %52, while
  // leaving the unrelated NaN-propagating select (%57) intact.
  tt.func public @test1(%in_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16>
    %c576_i32 = arith.constant 576 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_4 = arith.constant dense<9216> : tensor<32x1xi32>
    %cst_5 = arith.constant dense<16> : tensor<1x32xi32>
    %cst_6 = arith.constant dense<576> : tensor<1x32xi32>
    %cst_7 = arith.constant dense<0xFF800000> : tensor<32x32xf32>
    %cst_8 = arith.constant dense<16> : tensor<32x1xi32>
    %c32_i32 = arith.constant 32 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
    %4 = tt.splat %1 : i32 -> tensor<32x1xi32>
    %5 = arith.addi %4, %3 : tensor<32x1xi32>
    %6 = tt.expand_dims %2 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %7 = arith.remsi %5, %cst_8 : tensor<32x1xi32>
    %8 = arith.divsi %5, %cst_8 : tensor<32x1xi32>
    %9 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>>
    %10 = tt.addptr %9, %7 : tensor<32x1x!tt.ptr<f16>>, tensor<32x1xi32>
    %11 = tt.load %10 evictionPolicy = evict_last : tensor<32x1x!tt.ptr<f16>>
    %12 = arith.extf %11 : tensor<32x1xf16> to tensor<32x1xf32>
    %13 = tt.broadcast %7 : tensor<32x1xi32> -> tensor<32x32xi32>
    %14 = arith.muli %8, %cst_4 : tensor<32x1xi32>
    %15 = tt.broadcast %14 : tensor<32x1xi32> -> tensor<32x32xi32>
    %16 = tt.splat %in_ptr0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>>
    %17 = tt.broadcast %12 : tensor<32x1xf32> -> tensor<32x32xf32>
    %_tmp5 = scf.for %r0_offset = %c0_i32 to %c576_i32 step %c32_i32 iter_args(%_tmp5_9 = %cst_7) -> (tensor<32x32xf32>) : i32 {
      %44 = tt.splat %r0_offset : i32 -> tensor<1x32xi32>
      %45 = arith.addi %44, %6 : tensor<1x32xi32>
      // Mask condition: always true because r0_offset <= 544 and %6 < 32.
      %46 = arith.cmpi slt, %45, %cst_6 : tensor<1x32xi32>
      %47 = arith.muli %45, %cst_5 : tensor<1x32xi32>
      %48 = tt.broadcast %47 : tensor<1x32xi32> -> tensor<32x32xi32>
      %49 = arith.addi %13, %48 : tensor<32x32xi32>
      %50 = arith.addi %49, %15 : tensor<32x32xi32>
      %51 = tt.addptr %16, %50 : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
      %52 = tt.broadcast %46 : tensor<1x32xi1> -> tensor<32x32xi1>
      %53 = tt.load %51, %52, %cst evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
      %54 = arith.extf %53 : tensor<32x32xf16> to tensor<32x32xf32>
      %55 = arith.addf %54, %17 : tensor<32x32xf32>
      %mask = arith.cmpf ogt, %_tmp5_9, %55 : tensor<32x32xf32>
      %56 = arith.cmpf une, %_tmp5_9, %_tmp5_9 : tensor<32x32xf32>
      %mask_10 = arith.ori %mask, %56 : tensor<32x32xi1>
      %57 = arith.select %mask_10, %_tmp5_9, %55 : tensor<32x32xi1>, tensor<32x32xf32>
      %58 = arith.select %52, %57, %_tmp5_9 : tensor<32x32xi1>, tensor<32x32xf32>
      scf.yield %58 : tensor<32x32xf32>
    }
    tt.return
  }
  // CHECK: tt.func public @test1
  // CHECK: scf.for
  // CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
  // CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
  // CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32>
  // CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1>
  // CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32>
  // CHECK: scf.yield [[SEL]] : tensor<32x32xf32>

  // test2: same always-true mask pattern ((arg6 + [0, 8)) < 128 with the loop
  // running arg6 from 0 to 128 step 8), but the masked load feeds an scf.if and
  // the loop yields three masked selects.  The pass should unmask the load and
  // forward the scf.if results directly to scf.yield.
  tt.func public @test2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x8xf32>
    %c8_i32 = arith.constant 8 : i32
    %c128_i32 = arith.constant 128 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x8xf32>
    %cst_2 = arith.constant dense<16384> : tensor<64x1xi32>
    %cst_3 = arith.constant dense<128> : tensor<1x8xi32>
    %c0_i32 = arith.constant 0 : i32
    %cst_4 = arith.constant dense<128> : tensor<64x1xi32>
    %c64_i32 = arith.constant 64 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c64_i32 : i32
    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %4 = tt.splat %1 : i32 -> tensor<64x1xi32>
    %5 = arith.addi %4, %3 : tensor<64x1xi32>
    %6 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32>
    %8 = arith.remsi %5, %cst_4 : tensor<64x1xi32>
    %9 = arith.divsi %5, %cst_4 : tensor<64x1xi32>
    %10 = tt.broadcast %8 : tensor<64x1xi32> -> tensor<64x8xi32>
    %11 = arith.muli %9, %cst_2 : tensor<64x1xi32>
    %12 = tt.broadcast %11 : tensor<64x1xi32> -> tensor<64x8xi32>
    %13 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x8x!tt.ptr<f32>>
    %14:3 = scf.for %arg6 = %c0_i32 to %c128_i32 step %c8_i32 iter_args(%arg7 = %cst_1, %arg8 = %cst_1, %arg9 = %cst_1) -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>) : i32 {
      %25 = tt.splat %arg6 : i32 -> tensor<1x8xi32>
      %26 = arith.addi %25, %7 : tensor<1x8xi32>
      // Mask condition: always true because arg6 <= 120 and %7 < 8.
      %27 = arith.cmpi slt, %26, %cst_3 : tensor<1x8xi32>
      %28 = arith.muli %26, %cst_3 : tensor<1x8xi32>
      %29 = tt.broadcast %28 : tensor<1x8xi32> -> tensor<64x8xi32>
      %30 = arith.addi %10, %29 : tensor<64x8xi32>
      %31 = arith.addi %30, %12 : tensor<64x8xi32>
      %32 = tt.addptr %13, %31 : tensor<64x8x!tt.ptr<f32>>, tensor<64x8xi32>
      %33 = tt.broadcast %27 : tensor<1x8xi1> -> tensor<64x8xi1>
      %34 = tt.load %32, %33, %cst_1 evictionPolicy = evict_first : tensor<64x8x!tt.ptr<f32>>
      %35 = arith.cmpi eq, %arg6, %c0_i32 : i32
      %36:3 = scf.if %35 -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>) {
        scf.yield %cst_1, %34, %cst_0 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
      } else {
        %40 = arith.subf %34, %arg7 : tensor<64x8xf32>
        %41 = arith.addf %arg9, %cst_0 : tensor<64x8xf32>
        %42 = arith.divf %40, %41 : tensor<64x8xf32>
        %43 = arith.addf %arg7, %42 : tensor<64x8xf32>
        %44 = arith.subf %34, %43 : tensor<64x8xf32>
        %45 = arith.mulf %40, %44 : tensor<64x8xf32>
        %46 = arith.addf %arg8, %45 : tensor<64x8xf32>
        scf.yield %46, %43, %41 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
      }
      %37 = arith.select %33, %36#1, %arg7 : tensor<64x8xi1>, tensor<64x8xf32>
      %38 = arith.select %33, %36#0, %arg8 : tensor<64x8xi1>, tensor<64x8xf32>
      %39 = arith.select %33, %36#2, %arg9 : tensor<64x8xi1>, tensor<64x8xf32>
      scf.yield %37, %38, %39 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
    }
    tt.return
  }
  // CHECK: tt.func public @test2
  // CHECK: scf.for
  // CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<64x8x!tt.ptr<f32>>, tensor<64x8xi32>
  // CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_first : tensor<64x8x!tt.ptr<f32>>
  // CHECK: [[IF_RES:%.+]]:3 = scf.if {{.*}} -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>)
  // CHECK: scf.yield {{.*}}, [[LOAD]], {{.*}} : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
  // CHECK: else
  // Note: `CHECK-2:` is not a FileCheck directive (it would be silently
  // ignored); CHECK-COUNT-2 is the valid way to require both arith.subf
  // uses of the unmasked load in the else region.
  // CHECK-COUNT-2: arith.subf [[LOAD]], {{.*}} : tensor<64x8xf32>
  // CHECK: }
  // CHECK: scf.yield [[IF_RES]]#1, [[IF_RES]]#0, [[IF_RES]]#2 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
}

0 commit comments

Comments
 (0)