@@ -71,7 +71,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
7171
7272// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
7373// COM: Where the 'make_tensor_ptr' result is loop carried.
74- tt.func public @test_matmul (%a_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %b_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %c_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %M: i32 {tt.divisibility = 16 : i32 }, %N: i32 {tt.divisibility = 16 : i32 }, %K: i32 {tt.divisibility = 16 : i32 }, %stride_am: i32 {tt.divisibility = 16 : i32 }, %stride_bk: i32 {tt.divisibility = 16 : i32 }, %stride_cm: i32 {tt.divisibility = 16 : i32 }) {
74+ tt.func public @fuseLoadWithReshape3 (%a_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %b_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %c_ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %M: i32 {tt.divisibility = 16 : i32 }, %N: i32 {tt.divisibility = 16 : i32 }, %K: i32 {tt.divisibility = 16 : i32 }, %stride_am: i32 {tt.divisibility = 16 : i32 }, %stride_bk: i32 {tt.divisibility = 16 : i32 }, %stride_cm: i32 {tt.divisibility = 16 : i32 }) {
7575 %c127_i32 = arith.constant 127 : i32
7676 %c255_i32 = arith.constant 255 : i32
7777 %cst = arith.constant dense <0.000000e+00 > : tensor <256 x128 xf32 >
@@ -119,7 +119,7 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
119119 tt.store %24 , %accumulator#2 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x128 xf32 >>
120120 tt.return
121121}
122- // CHECK-LABEL: test_matmul
122+ // CHECK-LABEL: fuseLoadWithReshape3
123123// CHECK-NOT: tt.reshape
124124// CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
125125// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
@@ -132,3 +132,66 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
132132// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
133133// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
134134// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
135+
136+ // -----
137+
138+ // COM: tt.load -> tt.reshape -> tt.dot chain, in 2 loops.
139+ // COM: Where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
140+ tt.func public @fuseLoadWithReshape4 (%arg0: i32 , %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >) {
141+ %c0_i32 = arith.constant 0 : i32
142+ %c1_i32 = arith.constant 1 : i32
143+ %c2_i32 = arith.constant 2 : i32
144+ %c32_i32 = arith.constant 32 : i32
145+ %c1_i64 = arith.constant 1 : i64
146+ %c64_i64 = arith.constant 64 : i64
147+ %c256_i64 = arith.constant 256 : i64
148+ %cst = arith.constant dense <0.000000e+00 > : tensor <64 x64 xf32 >
149+ %7 = tt.make_tensor_ptr %arg1 , [%c1_i64 , %c64_i64 ], [%c64_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x32 xf16 >>
150+ %9 = tt.make_tensor_ptr %arg2 , [%c1_i64 , %c256_i64 , %c64_i64 ], [%c256_i64 , %c64_i64 , %c1_i64 ], [%c0_i32 , %c1_i32 , %c2_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x64 xf16 >>
151+ %10 = tt.advance %7 , [%arg0 , %c0_i32 ] : <tensor <64 x32 xf16 >>
152+ %11 = tt.load %10 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x32 xf16 >>
153+ %res1:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args (%arg4 = %arg0 ) -> (i32 ) : i32 {
154+ %adv = tt.advance %9 , [%arg4 , %c0_i32 ] : <tensor <1 x32 x64 xf16 >>
155+ %load = tt.load %adv {boundaryCheck = array<i32 : 1 , 2 >} : !tt.ptr <tensor <1 x32 x64 xf16 >>
156+ %reshape = tt.reshape %load : tensor <1 x32 x64 xf16 > -> tensor <32 x64 xf16 >
157+ %dot = tt.dot %11 , %reshape , %cst , inputPrecision = tf32 : tensor <64 x32 xf16 > * tensor <32 x64 xf16 > -> tensor <64 x64 xf32 >
158+ %add = arith.addi %arg4 , %c32_i32 : i32
159+ scf.yield %add : i32
160+ }
161+ %res2:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args (%arg4 = %arg0 ) -> (i32 ) : i32 {
162+ %adv = tt.advance %9 , [%arg4 , %c0_i32 ] : <tensor <1 x32 x64 xf16 >>
163+ %load = tt.load %adv {boundaryCheck = array<i32 : 2 , 1 >} : !tt.ptr <tensor <1 x32 x64 xf16 >>
164+ %reshape = tt.reshape %load : tensor <1 x32 x64 xf16 > -> tensor <32 x64 xf16 >
165+ %dot = tt.dot %11 , %reshape , %cst , inputPrecision = tf32 : tensor <64 x32 xf16 > * tensor <32 x64 xf16 > -> tensor <64 x64 xf32 >
166+ %add = arith.addi %arg4 , %c32_i32 : i32
167+ scf.yield %add : i32
168+ }
169+ tt.return
170+
171+ }
172+ // CHECK-LABEL: fuseLoadWithReshape4
173+ // CHECK-NOT: tt.reshape
174+ // CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
175+ // CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64
176+ // CHECK: [[ADD11:%.*]] = arith.addi [[MUL11]], %c256_i64 : i64
177+ // CHECK: [[TRUNC1:%.*]] = arith.trunci [[DIV1]] : i64 to i32
178+ // CHECK: [[MUL21:%.*]] = arith.muli %c0_i32, [[TRUNC1]] : i32
179+ // CHECK: [[ADD21:%.*]] = arith.addi [[MUL21]], %c1_i32 : i32
180+ // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD11]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD21]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
181+ // CHECK: [[DIV2:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
182+ // CHECK: [[MUL12:%.*]] = arith.muli %c1_i64, [[DIV2]] : i64
183+ // CHECK: [[ADD12:%.*]] = arith.addi [[MUL12]], %c256_i64 : i64
184+ // CHECK: [[TRUNC2:%.*]] = arith.trunci [[DIV2]] : i64 to i32
185+ // CHECK: [[MUL22:%.*]] = arith.muli %c0_i32, [[TRUNC2]] : i32
186+ // CHECK: [[ADD22:%.*]] = arith.addi [[MUL22]], %c1_i32 : i32
187+ // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
188+ // CHECK: scf.for
189+ // CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
190+ // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x64xf16>>
191+ // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
192+ // CHECK: scf.yield
193+ // CHECK: scf.for
194+ // CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
195+ // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<32x64xf16>>
196+ // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
197+ // CHECK: scf.yield
0 commit comments