
Commit 26ca64e

Fix codegen
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 2453d59 commit 26ca64e

2 files changed: +129 -61 lines changed
Lines changed: 121 additions & 56 deletions
@@ -1,65 +1,130 @@
 // RUN: triton-opt %s -split-input-file -triton-intel-fuse-reshape | FileCheck %s
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
-// COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
-tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
-%c0_i32 = arith.constant 0 : i32
-%c1_i32 = arith.constant 1 : i32
-%c2_i32 = arith.constant 2 : i32
-%c1_i64 = arith.constant 1 : i64
-%c2_i64 = arith.constant 2 : i64
-%c3_i64 = arith.constant 3 : i64
-%c1024_i64 = arith.constant 1024 : i64
-%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
-%0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c3_i64, %c1024_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
-%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
-%3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
-%4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
-%5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
-tt.return
+// COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
+tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
+%c0_i32 = arith.constant 0 : i32
+%c1_i32 = arith.constant 1 : i32
+%c2_i32 = arith.constant 2 : i32
+%c1_i64 = arith.constant 1 : i64
+%c2_i64 = arith.constant 2 : i64
+%c3_i64 = arith.constant 3 : i64
+%c1024_i64 = arith.constant 1024 : i64
+%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
+%0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c3_i64, %c1024_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
+%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
+%3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
+%4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
+%5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+tt.return
+}
+// CHECK-LABEL: fuseLoadWithReshape1
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c3_i64, %c2_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, [[ADD1]]], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
+// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
+// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+
+// -----
+
+// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
+// COM: where the 'make_tensor_ptr' result is not loop carried.
+tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
+%c0_i32 = arith.constant 0 : i32
+%c1_i32 = arith.constant 1 : i32
+%c1_i64 = arith.constant 1 : i64
+%c32_i32 = arith.constant 32 : i32
+%c1024_i32 = arith.constant 1024 : i32
+%c512_i64 = arith.constant 512 : i64
+%c1024_i64 = arith.constant 1024 : i64
+%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
+%0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c1_i64], [%c512_i64, %c1_i64, %c1024_i64], [%c1_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
+%res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
+%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
+%3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
+%2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
+%4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+%5 = arith.addi %arg5, %c32_i32 : i32
+scf.yield %4, %5 : tensor<256x256xf32>, i32
 }
-// CHECK-LABEL: fuseLoadWithReshape1
-// CHECK-NOT: tt.reshape
-// CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
-// CHECK: [[MUL:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
-// CHECK: [[ADD:%.*]] = arith.addi [[MUL]], %c0_i32 : i32
-// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
-// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+tt.return
 }
+// CHECK-LABEL: fuseLoadWithReshape2
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, %c512_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c32_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1_i64], [%c1_i64, %c1024_i64], [[[ADD2]], %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
+// CHECK: scf.for
+// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
+// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
 
 // -----
 
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
-// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
-// COM: where the 'make_tensor_ptr' result is not loop carried.
-tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
-%c0_i32 = arith.constant 0 : i32
-%c1_i32 = arith.constant 1 : i32
-%c1_i64 = arith.constant 1 : i64
-%c32_i32 = arith.constant 32 : i32
-%c1024_i32 = arith.constant 1024 : i32
-%c512_i64 = arith.constant 512 : i64
-%c1024_i64 = arith.constant 1024 : i64
-%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
-%0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c1_i64], [%c512_i64, %c1_i64, %c1024_i64], [%c1_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
-%res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
-%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-%3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
-%2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
-%4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
-%5 = arith.addi %arg5, %c32_i32 : i32
-scf.yield %4, %5 : tensor<256x256xf32>, i32
-}
-tt.return
+// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
+// COM: Where the 'make_tensor_ptr' result is loop carried.
+tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
+%c127_i32 = arith.constant 127 : i32
+%c255_i32 = arith.constant 255 : i32
+%cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
+%c32_i32 = arith.constant 32 : i32
+%c128_i32 = arith.constant 128 : i32
+%c0_i32 = arith.constant 0 : i32
+%c1_i64 = arith.constant 1 : i64
+%c256_i32 = arith.constant 256 : i32
+%c4_i32 = arith.constant 4 : i32
+%0 = tt.get_program_id x : i32
+%1 = arith.addi %M, %c255_i32 : i32
+%2 = arith.divsi %1, %c256_i32 : i32
+%3 = arith.addi %N, %c127_i32 : i32
+%4 = arith.divsi %3, %c128_i32 : i32
+%5 = arith.muli %4, %c4_i32 : i32
+%6 = arith.divsi %0, %5 : i32
+%7 = arith.muli %6, %c4_i32 : i32
+%8 = arith.subi %2, %7 : i32
+%9 = arith.minsi %8, %c4_i32 : i32
+%10 = arith.remsi %0, %5 : i32
+%11 = arith.remsi %10, %9 : i32
+%12 = arith.addi %7, %11 : i32
+%13 = arith.divsi %10, %9 : i32
+%14 = arith.muli %12, %c256_i32 : i32
+%15 = arith.extsi %M : i32 to i64
+%16 = arith.extsi %K : i32 to i64
+%17 = arith.extsi %stride_am : i32 to i64
+%18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %14, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
+%19 = arith.muli %13, %c128_i32 : i32
+%20 = arith.extsi %N : i32 to i64
+%21 = arith.extsi %stride_bk : i32 to i64
+%22 = tt.make_tensor_ptr %b_ptr, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x128xf32>>
+%accumulator:3 = scf.for %k = %c0_i32 to %K step %c32_i32 iter_args(%a_block_ptr = %18, %b_block_ptr = %22, %accumulator_0 = %cst) -> (!tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>) : i32 {
+%25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x256x32xf32>>
+%26 = tt.reshape %25 : tensor<1x256x32xf32> -> tensor<256x32xf32>
+%27 = tt.load %b_block_ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x128xf32>>
+%28 = tt.dot %26, %27, %cst, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
+%29 = arith.addf %accumulator_0, %28 : tensor<256x128xf32>
+%30 = tt.advance %a_block_ptr, [%c0_i32, %c0_i32, %c32_i32] : <tensor<1x256x32xf32>>
+%31 = tt.advance %b_block_ptr, [%c32_i32, %c0_i32] : <tensor<32x128xf32>>
+scf.yield %30, %31, %29 : !tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>
 }
-// CHECK-LABEL: fuseLoadWithReshape2
-// CHECK-NOT: tt.reshape
-// CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
-// CHECK: [[MUL:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
-// CHECK: [[ADD:%.*]] = arith.addi [[MUL]], %c0_i32 : i32
-// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c32_i32, [[ADD]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
-// CHECK: scf.for
-// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
-// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+%23 = arith.extsi %stride_cm : i32 to i64
+%24 = tt.make_tensor_ptr %c_ptr, [%15, %20], [%23, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x128xf32>>
+tt.store %24, %accumulator#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x128xf32>>
+tt.return
 }
+// CHECK-LABEL: test_matmul
+// CHECK-NOT: tt.reshape
+// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, %c1_i64 : i64
+// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %16 : i64
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
+// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c0_i32 : i32
+// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
+// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [%15, [[ADD1]]], [%17, %c1_i64], [%14, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
+// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
+// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
+// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
+// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
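For orientation, here is a worked reading of the new fuseLoadWithReshape1 CHECK lines, assuming the folding arithmetic added in FuseReshape.cpp below. The 3-D block pointer is built with shape [2, 1, 1024], strides [3, 1024, 1] and offsets [2, 1, 0]; when the leading dimension is dropped, its extent and offset are folded into the remaining innermost dimension:

new innermost shape  = firstStride * firstShape + old innermost shape  = 3 * 2 + 1024 = 1030   ([[MUL1]], [[ADD1]])
new innermost offset = trunc(firstStride) * firstOffset + old innermost offset = 3 * 2 + 0 = 6  ([[TRUNC]], [[MUL2]], [[ADD2]])

so the fused pointer the test expects is tt.make_tensor_ptr %arg1, [1, 1030], [1024, 1], [1, 6] : <tensor<32x256xbf16>>, which is exactly what the new [[PTR]] line checks.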

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,6 @@ class FuseReshape {
 auto tensorType = cast<RankedTensorType>(reshapeOp.getType());
 auto newPtrType =
 tt::PointerType::get(tensorType, ptrType.getAddressSpace());
-SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
-SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
 
 unsigned innermostDimIdx = 0;
 ArrayRef<int> order = makeTensorPtrOp.getOrder();
@@ -195,8 +193,15 @@ class FuseReshape {
 
 OpBuilder builder(makeTensorPtrOp);
 Location loc = makeTensorPtrOp.getLoc();
+Value firstShape = makeTensorPtrOp.getShape().front();
 Value firstStride = makeTensorPtrOp.getStrides().front();
 Value firstOffset = makeTensorPtrOp.getOffsets().front();
+
+SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
+newShape[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
+    loc, builder.create<arith::MulIOp>(loc, firstStride, firstShape),
+    newShape[innermostDimIdx - 1]);
+SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
 SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
 newOffsets[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
 loc,
@@ -372,10 +377,8 @@ class FuseReshape {
 SmallVector<Operation *> users(op->getUsers());
 if (users.size() > 2 || llvm::none_of(users, [&](Operation *user) {
 return user == yieldOp;
-})) {
-llvm::errs() << "at line " << __LINE__ << "\n";
+}))
 return false;
-}
 
 auto yieldedValUsedAfterLoop = [&op, &yieldOp]() {
 auto it =
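To make the new arithmetic concrete, below is a minimal standalone C++ sketch. It is not the pass's MLIR builder code and the struct/function names are hypothetical; it only reproduces, as plain integers, the values the fused make_tensor_ptr is checked for in the tests above: the shape fold added by this commit (firstStride * firstShape + the fold-target dimension's shape) together with the pre-existing offset fold (trunci(firstStride) * firstOffset + that dimension's offset).

// Hypothetical standalone sketch of the fold arithmetic (illustrative names only).
#include <cassert>
#include <cstdint>

struct FusedInnerDim {
  int64_t shape;   // new shape of the fold-target dimension (i64, like the shape operands)
  int32_t offset;  // new offset of the fold-target dimension (i32, like the offset operands)
};

// firstShape/firstStride/firstOffset describe the dropped leading dimension;
// innerShape/innerOffset are the original operands of the dimension the fold targets.
FusedInnerDim foldLeadingDim(int64_t firstShape, int64_t firstStride, int32_t firstOffset,
                             int64_t innerShape, int32_t innerOffset) {
  FusedInnerDim fused;
  // newShape[innermostDimIdx - 1] = firstStride * firstShape + innerShape
  fused.shape = firstStride * firstShape + innerShape;
  // newOffsets[innermostDimIdx - 1] = trunci(firstStride) * firstOffset + innerOffset
  fused.offset = static_cast<int32_t>(firstStride) * firstOffset + innerOffset;
  return fused;
}

int main() {
  // Operands from fuseLoadWithReshape1: shape [2, 1, 1024], strides [3, 1024, 1], offsets [2, 1, 0].
  FusedInnerDim fused = foldLeadingDim(/*firstShape=*/2, /*firstStride=*/3, /*firstOffset=*/2,
                                       /*innerShape=*/1024, /*innerOffset=*/0);
  assert(fused.shape == 1030 && fused.offset == 6); // the values matched by [[ADD1]] and [[ADD2]]
  return 0;
}

Plugging in the fuseLoadWithReshape2 operands (firstShape = 512, firstStride = 512, firstOffset = 1, fold-target shape 1024 and offset 32) gives 512 * 512 + 1024 and 512 * 1 + 32, the values that [[ADD1]] and [[ADD2]] match in the second test.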
