// RUN: triton-opt %s -split-input-file -triton-intel-fuse-reshape | FileCheck %s

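// COM: The fuse-reshape pass exercised here folds a tt.reshape that drops the
// COM: leading unit dimension of a 3-D block load into the defining
// COM: tt.make_tensor_ptr, so the load directly yields the 2-D tt.dot operand.
// COM: The checks below expect the dropped dimension to be folded into one of the
// COM: remaining dimensions k, with new_shape = shape[0] * (stride[0] / stride[k]) + shape[k]
// COM: and new_offset = offset[0] * (stride[0] / stride[k]) + offset[k], and the
// COM: strides, order, and boundaryCheck attributes collapsed to 2-D accordingly.
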
// COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c1_i64 = arith.constant 1 : i64
  %c4_i64 = arith.constant 4 : i64
  %c64_i64 = arith.constant 64 : i64
  %c1024_i64 = arith.constant 1024 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
  %0 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
  %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
  %3 = tt.load %0 {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x32x256xbf16>>
  %4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
  %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape1
// CHECK-NOT: tt.reshape
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c64_i64 : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop,
// COM: where the 'make_tensor_ptr' result is not loop carried.
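// COM: The fused 2-D make_tensor_ptr is expected to be materialized outside the
// COM: loop, since the original block pointer is defined above the loop.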
tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
  %c0_i32 = arith.constant 0 : i32
  %c32_i32 = arith.constant 32 : i32
  %c1024_i32 = arith.constant 1024 : i32
  %c32_i64 = arith.constant 32 : i64
  %c1_i64 = arith.constant 1 : i64
  %c512_i64 = arith.constant 512 : i64
  %c1024_i64 = arith.constant 1024 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
  %0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c32_i64], [%c1024_i64, %c1_i64, %c512_i64], [%c32_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
  %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
    %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
    %3 = tt.load %0 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<1x256x32xbf16>>
    %2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
    %4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
    %5 = arith.addi %arg5, %c32_i32 : i32
    scf.yield %4, %5 : tensor<256x256xf32>, i32
  }
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape2
// CHECK-NOT: tt.reshape
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c512_i64 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c32_i64 : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c32_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
// CHECK: scf.for
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop,
// COM: where the 'make_tensor_ptr' result is loop carried.
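// COM: The block pointer iter_arg, its tt.load, and the tt.advance feeding the
// COM: yield are all expected to be rewritten to the collapsed 2-D pointer type.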
tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
  %c127_i32 = arith.constant 127 : i32
  %c255_i32 = arith.constant 255 : i32
  %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
  %c32_i32 = arith.constant 32 : i32
  %c128_i32 = arith.constant 128 : i32
  %c0_i32 = arith.constant 0 : i32
  %c1_i64 = arith.constant 1 : i64
  %c256_i32 = arith.constant 256 : i32
  %c4_i32 = arith.constant 4 : i32
  %0 = tt.get_program_id x : i32
  %1 = arith.addi %M, %c255_i32 : i32
  %2 = arith.divsi %1, %c256_i32 : i32
  %3 = arith.addi %N, %c127_i32 : i32
  %4 = arith.divsi %3, %c128_i32 : i32
  %5 = arith.muli %4, %c4_i32 : i32
  %6 = arith.divsi %0, %5 : i32
  %7 = arith.muli %6, %c4_i32 : i32
  %8 = arith.subi %2, %7 : i32
  %9 = arith.minsi %8, %c4_i32 : i32
  %10 = arith.remsi %0, %5 : i32
  %11 = arith.remsi %10, %9 : i32
  %12 = arith.addi %7, %11 : i32
  %13 = arith.divsi %10, %9 : i32
  %14 = arith.muli %12, %c256_i32 : i32
  %15 = arith.extsi %M : i32 to i64
  %16 = arith.extsi %K : i32 to i64
  %17 = arith.extsi %stride_am : i32 to i64
  %18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %c128_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
  %19 = arith.muli %13, %c128_i32 : i32
  %20 = arith.extsi %N : i32 to i64
  %21 = arith.extsi %stride_bk : i32 to i64
  %22 = tt.make_tensor_ptr %b_ptr, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x128xf32>>
  %accumulator:3 = scf.for %k = %c0_i32 to %K step %c32_i32 iter_args(%a_block_ptr = %18, %b_block_ptr = %22, %accumulator_0 = %cst) -> (!tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>) : i32 {
    %25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x256x32xf32>>
    %26 = tt.reshape %25 : tensor<1x256x32xf32> -> tensor<256x32xf32>
    %27 = tt.load %b_block_ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x128xf32>>
    %28 = tt.dot %26, %27, %cst, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
    %29 = arith.addf %accumulator_0, %28 : tensor<256x128xf32>
    %30 = tt.advance %a_block_ptr, [%c0_i32, %c0_i32, %c32_i32] : <tensor<1x256x32xf32>>
    %31 = tt.advance %b_block_ptr, [%c32_i32, %c0_i32] : <tensor<32x128xf32>>
    scf.yield %30, %31, %29 : !tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>
  }
  %23 = arith.extsi %stride_cm : i32 to i64
  %24 = tt.make_tensor_ptr %c_ptr, [%15, %20], [%23, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x128xf32>>
  tt.store %24, %accumulator#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x128xf32>>
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape3
// CHECK-NOT: tt.reshape
// CHECK: [[EXT_M:%.*]] = arith.extsi %arg3 : i32 to i64
// CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], [[EXT_M]] : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c128_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<256x32xf32>>
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in two loops,
// COM: where the block ptr used by the loads in both loops is created by the same make_tensor_ptr operation.
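// COM: A separate fused 2-D make_tensor_ptr is expected for each loop that loads
// COM: through the shared block pointer (PTR1 and PTR2 in the checks below).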
tt.func public @fuseLoadWithReshape4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c32_i32 = arith.constant 32 : i32
  %c1_i64 = arith.constant 1 : i64
  %c64_i64 = arith.constant 64 : i64
  %c256_i64 = arith.constant 256 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
  %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16>>
  %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c256_i64, %c64_i64], [%c256_i64, %c64_i64, %c1_i64], [%c0_i32, %c1_i32, %c2_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x64xf16>>
  %10 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x32xf16>>
  %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16>>
  %res1:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
    %add = arith.addi %arg4, %c32_i32 : i32
    scf.yield %add : i32
  }
  %res2:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
    %add = arith.addi %arg4, %c32_i32 : i32
    scf.yield %add : i32
  }
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape4
// CHECK-NOT: tt.reshape
// CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
// CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64
// CHECK: [[ADD11:%.*]] = arith.addi [[MUL11]], %c256_i64 : i64
// CHECK: [[TRUNC1:%.*]] = arith.trunci [[DIV1]] : i64 to i32
// CHECK: [[MUL21:%.*]] = arith.muli %c0_i32, [[TRUNC1]] : i32
// CHECK: [[ADD21:%.*]] = arith.addi [[MUL21]], %c1_i32 : i32
// CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD11]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD21]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
// CHECK: [[DIV2:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
// CHECK: [[MUL12:%.*]] = arith.muli %c1_i64, [[DIV2]] : i64
// CHECK: [[ADD12:%.*]] = arith.addi [[MUL12]], %c256_i64 : i64
// CHECK: [[TRUNC2:%.*]] = arith.trunci [[DIV2]] : i64 to i32
// CHECK: [[MUL22:%.*]] = arith.muli %c0_i32, [[TRUNC2]] : i32
// CHECK: [[ADD22:%.*]] = arith.addi [[MUL22]], %c1_i32 : i32
// CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
// CHECK: scf.for
// CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
// CHECK: scf.yield
// CHECK: scf.for
// CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
// CHECK: scf.yield