Commit 5f674f1

Add test case
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 79199c6 commit 5f674f1

2 files changed: +67 −4 lines

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 65 additions & 2 deletions
@@ -71,7 +71,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
 
 // COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
 // COM: Where the 'make_tensor_ptr' result is loop carried.
-tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
+tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
   %c127_i32 = arith.constant 127 : i32
   %c255_i32 = arith.constant 255 : i32
   %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
@@ -119,7 +119,7 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
   tt.store %24, %accumulator#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x128xf32>>
   tt.return
 }
-// CHECK-LABEL: test_matmul
+// CHECK-LABEL: fuseLoadWithReshape3
 // CHECK-NOT: tt.reshape
 // CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
 // CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
@@ -132,3 +132,66 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
 // CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
 // CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
 // CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
+
+// -----
+
+// COM: tt.load -> tt.reshape -> tt.dot chain, in 2 loops.
+// COM: Where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
+tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c32_i32 = arith.constant 32 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %c64_i64 = arith.constant 64 : i64
+  %c256_i64 = arith.constant 256 : i64
+  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
+  %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16>>
+  %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c256_i64, %c64_i64], [%c256_i64, %c64_i64, %c1_i64], [%c0_i32, %c1_i32, %c2_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x64xf16>>
+  %10 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x32xf16>>
+  %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16>>
+  %res1:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
+    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
+    %load = tt.load %adv {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x64xf16>>
+    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
+    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+    %add = arith.addi %arg4, %c32_i32 : i32
+    scf.yield %add : i32
+  }
+  %res2:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
+    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
+    %load = tt.load %adv {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x32x64xf16>>
+    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
+    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+    %add = arith.addi %arg4, %c32_i32 : i32
+    scf.yield %add : i32
+  }
+  tt.return
+
+}
+// CHECK-LABEL: fuseLoadWithTrans4
+// CHECK-NOT: tt.reshape
+// CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
+// CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64
+// CHECK: [[ADD11:%.*]] = arith.addi [[MUL11]], %c256_i64 : i64
+// CHECK: [[TRUNC1:%.*]] = arith.trunci [[DIV1]] : i64 to i32
+// CHECK: [[MUL21:%.*]] = arith.muli %c0_i32, [[TRUNC1]] : i32
+// CHECK: [[ADD21:%.*]] = arith.addi [[MUL21]], %c1_i32 : i32
+// CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD11]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD21]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
+// CHECK: [[DIV2:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
+// CHECK: [[MUL12:%.*]] = arith.muli %c1_i64, [[DIV2]] : i64
+// CHECK: [[ADD12:%.*]] = arith.addi [[MUL12]], %c256_i64 : i64
+// CHECK: [[TRUNC2:%.*]] = arith.trunci [[DIV2]] : i64 to i32
+// CHECK: [[MUL22:%.*]] = arith.muli %c0_i32, [[TRUNC2]] : i32
+// CHECK: [[ADD22:%.*]] = arith.addi [[MUL22]], %c1_i32 : i32
+// CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
+// CHECK: scf.for
+// CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
+// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x64xf16>>
+// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+// CHECK: scf.yield
+// CHECK: scf.for
+// CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
+// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<32x64xf16>>
+// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+// CHECK: scf.yield
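
Note: the arithmetic encoded in the CHECK directives above can be worked by hand for the 3-D block pointer in @fuseLoadWithTrans4 (shape [1, 256, 64], strides [256, 64, 1], offsets [0, 1, 2]). The following standalone C++ sketch reproduces only that arithmetic; the variable names are illustrative and do not come from the pass implementation.

// Worked example of the fused make_tensor_ptr operands verified above.
// Inputs taken from the test; names are illustrative, not from FuseReshape.cpp.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t shape[]   = {1, 256, 64};
  const int64_t strides[] = {256, 64, 1};
  const int32_t offsets[] = {0, 1, 2};

  // [[DIV1]]: rows of dim 1 covered by one step in dim 0.
  const int64_t div = strides[0] / strides[1];              // 256 / 64 = 4
  // [[MUL11]] and [[ADD11]]: fused outer shape.
  const int64_t fusedShape = shape[0] * div + shape[1];     // 1 * 4 + 256 = 260
  // [[TRUNC1]], [[MUL21]], [[ADD21]]: fused outer offset.
  const int32_t fusedOffset =
      offsets[0] * static_cast<int32_t>(div) + offsets[1];  // 0 * 4 + 1 = 1

  // Matches the rewritten 2-D pointer:
  //   tt.make_tensor_ptr %arg2, [260, 64], [64, 1], [1, 2]
  assert(fusedShape == 260 && fusedOffset == 1);
  return 0;
}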

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ class FuseReshape {
   // Collect def-use chains originating at a `MakeTensorPtrOp` operation
   // and terminating at a candidate `tt::ReshapeOp` operation.
   // Note: A candidate `reshapeOp` must use the result of a `loadOp` using a
-  // ptr created the `MakeTensorPtrOp` rooting the def-use chain.
+  // ptr created by the `MakeTensorPtrOp` rooting the def-use chain.
   DefUseChainManager manager;
   moduleOp.walk([&](tt::ReshapeOp reshapeOp) {
     if (isCandidate(reshapeOp)) {
@@ -344,7 +344,7 @@ class FuseReshape {
 
   // Ensure the load boundary check doesn't check the outermost dimension.
   return llvm::none_of(loadOp.getBoundaryCheck(),
-                       [&](int val) { return val == 0; });
+                       [](int val) { return val == 0; });
 }
 
 // Prune chains that cannot be handled during fusion. For example, operations
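
The second hunk is a small cleanup: the predicate reads no enclosing locals, so the default reference capture `[&]` is dead weight and an empty capture list states that explicitly (a capture-free lambda can also convert to a plain function pointer). Below is a standalone sketch of the same shape, using std::none_of as a stand-in for llvm::none_of and a made-up boundary-check list.

// Capture-free predicate, mirroring the fixed call in FuseReshape.cpp.
// std::none_of stands in for llvm::none_of; the data is a hypothetical example.
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int> boundaryCheck = {1, 2};  // dims checked by the load

  // The lambda uses only its parameter, so it needs no capture:
  const bool skipsOutermost =
      std::none_of(boundaryCheck.begin(), boundaryCheck.end(),
                   [](int val) { return val == 0; });

  std::cout << std::boolalpha << skipsOutermost << '\n';  // prints: true
  return 0;
}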
