Skip to content

Commit 1219b01

Browse files
tfruan2000guacamoleo
authored andcommitted
[BACKEND] Fix the combineSelectAndIf when the user of select in ifOp. (triton-lang#5031)
The CombineTensorSelectAndIf pass currently doesn’t work correctly **when the user of select is inside the scf.if block**. For example: ```mlir %select = arith.select %cond, %trueVal, %falseVal : i32 %if = scf.if %cond -> (i32) { %sub = arith.subi %select, %val1 : i32 scf.yield %sub : i32 } else { %mul = arith.muli %select, %val2 : i32 scf.yield %mul : i32 } use %select ``` In this case, dom.dominates(ifOp, user) will return true, but directly using replaceAllUsesWith would lead to incorrect replacement behavior. ```mlir // without this pr (the user in ifOp use the result of ifOp) %if:2 = scf.if %cond -> (i32, i32) { %sub = arith.subi %if#1, %val1 : i32 scf.yield %sub, %trueVal : i32, i32 } else { %mul = arith.muli %if#1, %val2 : i32 scf.yield %mul, %falseVal : i32, i32 } use %if#1 ``` To address this, we need to adjust the user’s operand based on the specific region it is in. ```mlir // with this pr (the user in ifOp be canonicaled first) %if:2 = scf.if %cond -> (i32, i32) { %sub = arith.subi %trueVal, %val1 : i32 scf.yield %sub, %trueVal : i32, i32 } else { %mul = arith.muli %falseVal, %val2 : i32 scf.yield %mul, %falseVal : i32, i32 } use %if#1 ```
1 parent 86ccfbc commit 1219b01

File tree

2 files changed

+107
-30
lines changed

2 files changed

+107
-30
lines changed

lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mlir/IR/Dominance.h"
2+
#include "mlir/Support/LLVM.h"
23
#include "mlir/Transforms/Passes.h"
34
#include "triton/Analysis/Utility.h"
45
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -14,8 +15,52 @@ namespace gpu {
1415
#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF
1516
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
1617

17-
// Return true if the select could be merged into the If without breaking SSA
18-
// rules.
18+
/// The user of select maybe inside either the ThenRegion or ElseRegion of
19+
/// the scf.if. So, canonicalize user of select in scf.if first.
20+
static void canonicalizeSelectUsersInSCFIf(ModuleOp input) {
21+
llvm::MapVector<std::pair<Value, Value>, SmallVector<Operation *>>
22+
usersNeedreplaced;
23+
input.walk([&](arith::SelectOp selectOp) {
24+
auto *parentBlock = selectOp->getBlock();
25+
Value condition = selectOp.getOperand(0);
26+
Value trueVal = selectOp.getOperand(1);
27+
Value falseVal = selectOp.getOperand(2);
28+
Value resVal = selectOp.getResult();
29+
for (auto *condUser : condition.getUsers()) {
30+
if (!llvm::isa<scf::IfOp>(condUser))
31+
continue;
32+
scf::IfOp ifOp = llvm::cast<scf::IfOp>(condUser);
33+
for (auto *resUser : resVal.getUsers()) {
34+
if (ifOp->isProperAncestor(resUser)) {
35+
if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) !=
36+
nullptr) {
37+
// The user is inside the ThenRegion of the scf.if.
38+
usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back(
39+
resUser);
40+
} else {
41+
// The user is inside the ElseRegion of the scf.if.
42+
usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back(
43+
resUser);
44+
}
45+
}
46+
}
47+
}
48+
});
49+
50+
// Replace the operand of user.
51+
for (auto [replacedSrcAndDst, users] :
52+
llvm::make_early_inc_range(usersNeedreplaced)) {
53+
Value srcVal = replacedSrcAndDst.first;
54+
Value dstVal = replacedSrcAndDst.second;
55+
for (Operation *user : llvm::make_early_inc_range(users)) {
56+
srcVal.replaceUsesWithIf(
57+
dstVal, [&](OpOperand &use) { return use.getOwner() == user; });
58+
}
59+
}
60+
}
61+
62+
/// Return true if the select could be merged into the If without breaking SSA
63+
/// rules.
1964
static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp,
2065
DominanceInfo &dom) {
2166
// If needs to be dominated by the select.
@@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass
3883
void runOnOperation() override {
3984
MLIRContext *context = &getContext();
4085
ModuleOp m = getOperation();
41-
DominanceInfo dom(m);
86+
canonicalizeSelectUsersInSCFIf(m);
4287

4388
// Go over the arith.select ops, look if there is an if
4489
// with the same condition.
90+
DominanceInfo dom(m);
4591
llvm::MapVector<scf::IfOp, SmallVector<arith::SelectOp>> selectToIf;
4692
m.walk([&](arith::SelectOp selectOp) {
4793
// Look if there is an if in the same block, with the same condition.
Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,77 @@
11
// RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s
22

3-
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
4-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
53
// CHECK-LABEL: @select_if_combine
6-
tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} {
7-
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00>
8-
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked>
9-
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00>
10-
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked>
11-
// CHECK-NOT: arith.select
12-
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked>
13-
// CHECK: %[[IF_RES:.*]] = scf.if
14-
scf.if %cnd {
15-
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked>
16-
// CHECK: scf.yield %[[CST0]]
17-
}
18-
// CHECK: else
19-
// CHECK: scf.yield %[[CST1]]
20-
// CHECK: tt.store %{{.*}}, %[[IF_RES]]
21-
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked>
22-
tt.return
4+
tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr<f32>>, %cnd: i1) {
5+
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00>
6+
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32>
7+
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00>
8+
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32>
9+
// CHECK-NOT: arith.select
10+
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32>
11+
// CHECK: %[[R:.+]] = scf.if %{{.*}}
12+
// CHECK: tt.store %{{.*}}, %{{.*}}
13+
// CHECK: scf.yield %[[CST0]]
14+
// CHECK: } else {
15+
// CHECK: scf.yield %[[CST1]]
16+
// CHECK: }
17+
scf.if %cnd {
18+
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>>
2319
}
20+
// CHECK: tt.store %{{.*}}, %[[R]]
21+
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>>
22+
tt.return
2423
}
2524

2625
// -----
27-
2826
// CHECK-LABEL: @if_multiple_sel
2927
tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){
30-
// CHECK-NOT: select
31-
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) {
32-
// CHECK: scf.yield {{.*}} : i32, i32, f32
33-
// CHECK: } else {
34-
// CHECK: scf.yield {{.*}} : i32, i32, f32
35-
// CHECK: }
36-
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32
28+
// CHECK-NOT: arith.select
3729
%0 = arith.select %arg0, %arg1, %arg2 : i32
3830
%1 = arith.select %arg0, %arg3, %arg4 : f32
31+
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) {
32+
// CHECK: scf.yield {{.*}} : i32, i32, f32
33+
// CHECK: } else {
34+
// CHECK: scf.yield {{.*}} : i32, i32, f32
35+
// CHECK: }
3936
%2 = scf.if %arg0 -> (i32) {
4037
%3 = arith.subi %arg1, %arg2 : i32
4138
scf.yield %3 : i32
4239
} else {
4340
scf.yield %arg1 : i32
4441
}
42+
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32
4543
tt.return %0, %1, %2 : i32, f32, i32
4644
}
45+
46+
// -----
47+
// CHECK-LABEL: tt.func @users_in_if(
48+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1
49+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32
50+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32
51+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32
52+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32
53+
tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) {
54+
// CHECK: %[[CST:.*]] = arith.constant 8 : i32
55+
%c8_i32 = arith.constant 8 : i32
56+
// CHECK-NOT: arith.select
57+
%0 = arith.select %arg0, %arg1, %arg2 : i32
58+
%1 = arith.select %arg0, %arg3, %arg4 : f32
59+
// CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) {
60+
// CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32
61+
// CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32
62+
// CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32
63+
// CHECK: } else {
64+
// CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32
65+
// CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32
66+
// CHECK: }
67+
%2:2 = scf.if %arg0 -> (i32, i32) {
68+
%3 = arith.muli %0, %arg2 : i32
69+
%4 = arith.addi %0, %c8_i32 : i32
70+
scf.yield %3, %4 : i32, i32
71+
} else {
72+
%3 = arith.subi %0, %c8_i32 : i32
73+
scf.yield %arg1, %3 : i32, i32
74+
}
75+
// CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32
76+
tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32
77+
}

0 commit comments

Comments
 (0)