// RUN: triton-opt %s -split-input-file -triton-intel-fuse-reshape | FileCheck %s

-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
-  // COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
-  tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c2_i32 = arith.constant 2 : i32
-    %c1_i64 = arith.constant 1 : i64
-    %c2_i64 = arith.constant 2 : i64
-    %c3_i64 = arith.constant 3 : i64
-    %c1024_i64 = arith.constant 1024 : i64
-    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
-    %0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c3_i64, %c1024_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
-    %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
-    %3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
-    %4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
-    %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
-    tt.return
+// COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
+tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %c2_i64 = arith.constant 2 : i64
+  %c3_i64 = arith.constant 3 : i64
+  %c1024_i64 = arith.constant 1024 : i64
+  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
+  %0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c3_i64, %c1024_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
+  %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
+  %3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
+  %4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
+  %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+  tt.return
+}
+// CHECK-LABEL: fuseLoadWithReshape1
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c3_i64, %c2_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, [[ADD1]]], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
+// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
+// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+
+// -----
+
+// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
+// COM: where the 'make_tensor_ptr' result is not loop carried.
+tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %c32_i32 = arith.constant 32 : i32
+  %c1024_i32 = arith.constant 1024 : i32
+  %c512_i64 = arith.constant 512 : i64
+  %c1024_i64 = arith.constant 1024 : i64
+  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
+  %0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c1_i64], [%c512_i64, %c1_i64, %c1024_i64], [%c1_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
+  %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
+    %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
+    %3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
+    %2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
+    %4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+    %5 = arith.addi %arg5, %c32_i32 : i32
+    scf.yield %4, %5 : tensor<256x256xf32>, i32
  }
-  // CHECK-LABEL: fuseLoadWithReshape1
-  // CHECK-NOT: tt.reshape
-  // CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
-  // CHECK: [[MUL:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
-  // CHECK: [[ADD:%.*]] = arith.addi [[MUL]], %c0_i32 : i32
-  // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
-  // CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-  // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+  tt.return
}
+// CHECK-LABEL: fuseLoadWithReshape2
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, %c512_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c32_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1_i64], [%c1_i64, %c1024_i64], [[[ADD2]], %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
+// CHECK: scf.for
+// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
+// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>

// -----

-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
-  // COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
-  // COM: where the 'make_tensor_ptr' result is not loop carried.
-  tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %c1_i64 = arith.constant 1 : i64
-    %c32_i32 = arith.constant 32 : i32
-    %c1024_i32 = arith.constant 1024 : i32
-    %c512_i64 = arith.constant 512 : i64
-    %c1024_i64 = arith.constant 1024 : i64
-    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
-    %0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c1_i64], [%c512_i64, %c1_i64, %c1024_i64], [%c1_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
-    %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
-      %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-      %3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
-      %2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
-      %4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
-      %5 = arith.addi %arg5, %c32_i32 : i32
-      scf.yield %4, %5 : tensor<256x256xf32>, i32
-    }
-    tt.return
+// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
+// COM: where the 'make_tensor_ptr' result is loop carried.
+tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
+  %c127_i32 = arith.constant 127 : i32
+  %c255_i32 = arith.constant 255 : i32
+  %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
+  %c32_i32 = arith.constant 32 : i32
+  %c128_i32 = arith.constant 128 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %c256_i32 = arith.constant 256 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %0 = tt.get_program_id x : i32
+  %1 = arith.addi %M, %c255_i32 : i32
+  %2 = arith.divsi %1, %c256_i32 : i32
+  %3 = arith.addi %N, %c127_i32 : i32
+  %4 = arith.divsi %3, %c128_i32 : i32
+  %5 = arith.muli %4, %c4_i32 : i32
+  %6 = arith.divsi %0, %5 : i32
+  %7 = arith.muli %6, %c4_i32 : i32
+  %8 = arith.subi %2, %7 : i32
+  %9 = arith.minsi %8, %c4_i32 : i32
+  %10 = arith.remsi %0, %5 : i32
+  %11 = arith.remsi %10, %9 : i32
+  %12 = arith.addi %7, %11 : i32
+  %13 = arith.divsi %10, %9 : i32
+  %14 = arith.muli %12, %c256_i32 : i32
+  %15 = arith.extsi %M : i32 to i64
+  %16 = arith.extsi %K : i32 to i64
+  %17 = arith.extsi %stride_am : i32 to i64
+  %18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %14, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
+  %19 = arith.muli %13, %c128_i32 : i32
+  %20 = arith.extsi %N : i32 to i64
+  %21 = arith.extsi %stride_bk : i32 to i64
+  %22 = tt.make_tensor_ptr %b_ptr, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x128xf32>>
+  %accumulator:3 = scf.for %k = %c0_i32 to %K step %c32_i32 iter_args(%a_block_ptr = %18, %b_block_ptr = %22, %accumulator_0 = %cst) -> (!tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>) : i32 {
+    %25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x256x32xf32>>
+    %26 = tt.reshape %25 : tensor<1x256x32xf32> -> tensor<256x32xf32>
+    %27 = tt.load %b_block_ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x128xf32>>
+    %28 = tt.dot %26, %27, %cst, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
+    %29 = arith.addf %accumulator_0, %28 : tensor<256x128xf32>
+    %30 = tt.advance %a_block_ptr, [%c0_i32, %c0_i32, %c32_i32] : <tensor<1x256x32xf32>>
+    %31 = tt.advance %b_block_ptr, [%c32_i32, %c0_i32] : <tensor<32x128xf32>>
+    scf.yield %30, %31, %29 : !tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>
  }
-  // CHECK-LABEL: fuseLoadWithReshape2
-  // CHECK-NOT: tt.reshape
-  // CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
-  // CHECK: [[MUL:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
-  // CHECK: [[ADD:%.*]] = arith.addi [[MUL]], %c0_i32 : i32
-  // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c32_i32, [[ADD]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
-  // CHECK: scf.for
-  // CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
-  // CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+  %23 = arith.extsi %stride_cm : i32 to i64
+  %24 = tt.make_tensor_ptr %c_ptr, [%15, %20], [%23, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x128xf32>>
+  tt.store %24, %accumulator#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x128xf32>>
+  tt.return
}
+// CHECK-LABEL: test_matmul
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, %c1_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %16 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c0_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [%15, [[ADD1]]], [%17, %c1_i64], [%14, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
+// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
+// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
+// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
+// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>