|
1 | 1 | // RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s |
2 | 2 |
|
3 | | -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> |
4 | | -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { |
5 | 3 | // CHECK-LABEL: @select_if_combine |
6 | | - tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} { |
7 | | - // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> |
8 | | - %cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked> |
9 | | - // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> |
10 | | - %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked> |
11 | | - // CHECK-NOT: arith.select |
12 | | - %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked> |
13 | | - // CHECK: %[[IF_RES:.*]] = scf.if |
14 | | - scf.if %cnd { |
15 | | - tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked> |
16 | | - // CHECK: scf.yield %[[CST0]] |
17 | | - } |
18 | | - // CHECK: else |
19 | | - // CHECK: scf.yield %[[CST1]] |
20 | | - // CHECK: tt.store %{{.*}}, %[[IF_RES]] |
21 | | - tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked> |
22 | | - tt.return |
| 4 | +tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr<f32>>, %cnd: i1) { |
| 5 | + // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> |
| 6 | + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> |
| 7 | + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> |
| 8 | + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> |
| 9 | + // CHECK-NOT: arith.select |
| 10 | + %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> |
| 11 | + // CHECK: %[[R:.+]] = scf.if %{{.*}} |
| 12 | + // CHECK: tt.store %{{.*}}, %{{.*}} |
| 13 | + // CHECK: scf.yield %[[CST0]] |
| 14 | + // CHECK: } else { |
| 15 | + // CHECK: scf.yield %[[CST1]] |
| 16 | + // CHECK: } |
| 17 | + scf.if %cnd { |
| 18 | + tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>> |
23 | 19 | } |
| 20 | + // CHECK: tt.store %{{.*}}, %[[R]] |
| 21 | + tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>> |
| 22 | + tt.return |
24 | 23 | } |
25 | 24 |
|
26 | 25 | // ----- |
27 | | - |
28 | 26 | // CHECK-LABEL: @if_multiple_sel |
29 | 27 | tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){ |
30 | | -// CHECK-NOT: select |
31 | | -// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { |
32 | | -// CHECK: scf.yield {{.*}} : i32, i32, f32 |
33 | | -// CHECK: } else { |
34 | | -// CHECK: scf.yield {{.*}} : i32, i32, f32 |
35 | | -// CHECK: } |
36 | | -// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 |
| 28 | + // CHECK-NOT: arith.select |
37 | 29 | %0 = arith.select %arg0, %arg1, %arg2 : i32 |
38 | 30 | %1 = arith.select %arg0, %arg3, %arg4 : f32 |
| 31 | + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { |
| 32 | + // CHECK: scf.yield {{.*}} : i32, i32, f32 |
| 33 | + // CHECK: } else { |
| 34 | + // CHECK: scf.yield {{.*}} : i32, i32, f32 |
| 35 | + // CHECK: } |
39 | 36 | %2 = scf.if %arg0 -> (i32) { |
40 | 37 | %3 = arith.subi %arg1, %arg2 : i32 |
41 | 38 | scf.yield %3 : i32 |
42 | 39 | } else { |
43 | 40 | scf.yield %arg1 : i32 |
44 | 41 | } |
| 42 | + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 |
45 | 43 | tt.return %0, %1, %2 : i32, f32, i32 |
46 | 44 | } |
| 45 | + |
| 46 | +// ----- |
| 47 | +// CHECK-LABEL: tt.func @users_in_if( |
| 48 | +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 |
| 49 | +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32 |
| 50 | +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32 |
| 51 | +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32 |
| 52 | +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32 |
| 53 | +tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) { |
| 54 | + // CHECK: %[[CST:.*]] = arith.constant 8 : i32 |
| 55 | + %c8_i32 = arith.constant 8 : i32 |
| 56 | + // CHECK-NOT: arith.select |
| 57 | + %0 = arith.select %arg0, %arg1, %arg2 : i32 |
| 58 | + %1 = arith.select %arg0, %arg3, %arg4 : f32 |
| 59 | + // CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) { |
| 60 | + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32 |
| 61 | + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32 |
| 62 | + // CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32 |
| 63 | + // CHECK: } else { |
| 64 | + // CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32 |
| 65 | + // CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32 |
| 66 | + // CHECK: } |
| 67 | + %2:2 = scf.if %arg0 -> (i32, i32) { |
| 68 | + %3 = arith.muli %0, %arg2 : i32 |
| 69 | + %4 = arith.addi %0, %c8_i32 : i32 |
| 70 | + scf.yield %3, %4 : i32, i32 |
| 71 | + } else { |
| 72 | + %3 = arith.subi %0, %c8_i32 : i32 |
| 73 | + scf.yield %arg1, %3 : i32, i32 |
| 74 | + } |
| 75 | + // CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32 |
| 76 | + tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32 |
| 77 | +} |
0 commit comments