Skip to content

Commit 7fb33e9

Browse files
committed
Handle select operations that get operand from operation returning multiple results
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 1a9d656 commit 7fb33e9

File tree

2 files changed

+77
-5
lines changed

2 files changed

+77
-5
lines changed

test/Triton/Intel/RemoveMasks/unnecessary-masks.mlir

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,79 @@ module {
5151
}
5252
tt.return
5353
}
54-
// CHECK: tt.func public @test1([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
54+
// CHECK: tt.func public @test1
5555
// CHECK: scf.for
5656
// CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<32x32x!tt.ptr<f16>>, tensor<32x32xi32>
5757
// CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_last : tensor<32x32x!tt.ptr<f16>>
5858
// CHECK: arith.extf [[LOAD]] : tensor<32x32xf16> to tensor<32x32xf32>
5959
// CHECK: [[ORI:%.+]] = arith.ori {{.*}} : tensor<32x32xi1>
6060
// CHECK: [[SEL:%.+]] = arith.select [[ORI]], {{.*}}, {{.*}} : tensor<32x32xi1>, tensor<32x32xf32>
6161
// CHECK: scf.yield [[SEL]] : tensor<32x32xf32>
62-
// CHECK: }
62+
63+
tt.func public @test2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
64+
%cst = arith.constant 0.000000e+00 : f32
65+
%cst_0 = arith.constant dense<1.000000e+00> : tensor<64x8xf32>
66+
%c8_i32 = arith.constant 8 : i32
67+
%c128_i32 = arith.constant 128 : i32
68+
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x8xf32>
69+
%cst_2 = arith.constant dense<16384> : tensor<64x1xi32>
70+
%cst_3 = arith.constant dense<128> : tensor<1x8xi32>
71+
%c0_i32 = arith.constant 0 : i32
72+
%cst_4 = arith.constant dense<128> : tensor<64x1xi32>
73+
%c64_i32 = arith.constant 64 : i32
74+
%0 = tt.get_program_id x : i32
75+
%1 = arith.muli %0, %c64_i32 : i32
76+
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
77+
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
78+
%4 = tt.splat %1 : i32 -> tensor<64x1xi32>
79+
%5 = arith.addi %4, %3 : tensor<64x1xi32>
80+
%6 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
81+
%7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32>
82+
%8 = arith.remsi %5, %cst_4 : tensor<64x1xi32>
83+
%9 = arith.divsi %5, %cst_4 : tensor<64x1xi32>
84+
%10 = tt.broadcast %8 : tensor<64x1xi32> -> tensor<64x8xi32>
85+
%11 = arith.muli %9, %cst_2 : tensor<64x1xi32>
86+
%12 = tt.broadcast %11 : tensor<64x1xi32> -> tensor<64x8xi32>
87+
%13 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x8x!tt.ptr<f32>>
88+
%14:3 = scf.for %arg6 = %c0_i32 to %c128_i32 step %c8_i32 iter_args(%arg7 = %cst_1, %arg8 = %cst_1, %arg9 = %cst_1) -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>) : i32 {
89+
%25 = tt.splat %arg6 : i32 -> tensor<1x8xi32>
90+
%26 = arith.addi %25, %7 : tensor<1x8xi32>
91+
%27 = arith.cmpi slt, %26, %cst_3 : tensor<1x8xi32>
92+
%28 = arith.muli %26, %cst_3 : tensor<1x8xi32>
93+
%29 = tt.broadcast %28 : tensor<1x8xi32> -> tensor<64x8xi32>
94+
%30 = arith.addi %10, %29 : tensor<64x8xi32>
95+
%31 = arith.addi %30, %12 : tensor<64x8xi32>
96+
%32 = tt.addptr %13, %31 : tensor<64x8x!tt.ptr<f32>>, tensor<64x8xi32>
97+
%33 = tt.broadcast %27 : tensor<1x8xi1> -> tensor<64x8xi1>
98+
%34 = tt.load %32, %33, %cst_1 evictionPolicy = evict_first : tensor<64x8x!tt.ptr<f32>>
99+
%35 = arith.cmpi eq, %arg6, %c0_i32 : i32
100+
%36:3 = scf.if %35 -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>) {
101+
scf.yield %cst_1, %34, %cst_0 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
102+
} else {
103+
%40 = arith.subf %34, %arg7 : tensor<64x8xf32>
104+
%41 = arith.addf %arg9, %cst_0 : tensor<64x8xf32>
105+
%42 = arith.divf %40, %41 : tensor<64x8xf32>
106+
%43 = arith.addf %arg7, %42 : tensor<64x8xf32>
107+
%44 = arith.subf %34, %43 : tensor<64x8xf32>
108+
%45 = arith.mulf %40, %44 : tensor<64x8xf32>
109+
%46 = arith.addf %arg8, %45 : tensor<64x8xf32>
110+
scf.yield %46, %43, %41 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
111+
}
112+
%37 = arith.select %33, %36#1, %arg7 : tensor<64x8xi1>, tensor<64x8xf32>
113+
%38 = arith.select %33, %36#0, %arg8 : tensor<64x8xi1>, tensor<64x8xf32>
114+
%39 = arith.select %33, %36#2, %arg9 : tensor<64x8xi1>, tensor<64x8xf32>
115+
scf.yield %37, %38, %39 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
116+
}
117+
tt.return
118+
}
119+
// CHECK: tt.func public @test2
120+
// CHECK: scf.for
121+
// CHECK: [[PTR:%.+]] = tt.addptr {{.*}} : tensor<64x8x!tt.ptr<f32>>, tensor<64x8xi32>
122+
// CHECK: [[LOAD:%.+]] = tt.load [[PTR]] evictionPolicy = evict_first : tensor<64x8x!tt.ptr<f32>>
123+
// CHECK: [[IF_RES:%.+]]:3 = scf.if {{.*}} -> (tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>)
124+
// CHECK: scf.yield {{.*}}, [[LOAD]], {{.*}} : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
125+
// CHECK: else
126+
// CHECK-2: arith.subf [[LOAD]], {{.*}} : tensor<64x8xf32
127+
// CHECK: }
128+
// CHECK: scf.yield [[IF_RES]]#1, [[IF_RES]]#0, [[IF_RES]]#2 : tensor<64x8xf32>, tensor<64x8xf32>, tensor<64x8xf32>
63129
}

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,15 @@ static Operation *dropMask(Operation *op, bool maskVal) {
4646
}
4747
})
4848
.Case<arith::SelectOp>([&](auto selectOp) {
49-
selectOp->replaceAllUsesWith(
50-
(maskVal ? selectOp.getTrueValue() : selectOp.getFalseValue())
51-
.getDefiningOp());
49+
Value origRes = selectOp.getResult();
50+
Value selectedVal =
51+
(maskVal ? selectOp.getTrueValue() : selectOp.getFalseValue());
52+
Value newRes = selectedVal;
53+
if (auto opResult = dyn_cast<OpResult>(selectedVal)) {
54+
Operation *defOp = opResult.getDefiningOp();
55+
newRes = defOp->getOpResult(opResult.getResultNumber());
56+
}
57+
origRes.replaceAllUsesWith(newRes);
5258
});
5359

5460
return nullptr;

0 commit comments

Comments
 (0)