
Commit 3f2cc86

Authored by etiotto and Copilot

[RemoveLayoutConversions]: Update unit tests (#4900)

This PR updates unit tests for the Intel version of the "RemoveLayoutConversions" optimization pass. The changes focus on improving test documentation and verification patterns to better validate the removal of unnecessary layout conversion operations.

---------
Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Copilot <[email protected]>

1 parent d5ee4fc · commit 3f2cc86
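As context for the checks below: the pass retypes layout-anchoring ops (tt.make_tensor_ptr, tt.store) to the DPAS encoding produced by tt.dot, so that intermediate ttg.convert_layout ops become dead. A minimal before/after sketch, with illustrative value names and elided operands rather than code taken verbatim from the test:

// Before: the dot result is converted back to a blocked layout for the store.
%d = tt.dot %a, %b, %acc, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%c = ttg.convert_layout %d : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked>
%p = tt.make_tensor_ptr %base, ... : <tensor<64x256xf32, #blocked>>
tt.store %p, %c : !tt.ptr<tensor<64x256xf32, #blocked>>

// After: the tensor pointer and the store are retyped to the DPAS encoding,
// and the ttg.convert_layout is removed.
%d = tt.dot %a, %b, %acc, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%p = tt.make_tensor_ptr %base, ... : <tensor<64x256xf32, #dpas>>
tt.store %p, %d : !tt.ptr<tensor<64x256xf32, #dpas>>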

File tree

1 file changed (+45, -30 lines)

test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir

Lines changed: 45 additions & 30 deletions
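The updated checks lean on standard FileCheck directive semantics: COM: lines are comments that FileCheck ignores; CHECK: patterns must match in order; CHECK-NOT: asserts the pattern does not occur between the surrounding matches; CHECK-DAG: lets adjacent patterns match in any relative order; CHECK-NEXT: must match on the line immediately after the previous match. For orientation, a RUN line in the usual Triton lit-test style is sketched below; the test's actual RUN line sits above the first hunk and is not shown in this diff, so the exact pass flag here is an assumption:

// RUN: triton-opt %s -split-input-file --tritonintelgpu-remove-layout-conversions | FileCheck %s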
@@ -75,16 +75,16 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
 
 // -----
 
-// COM: Case 2:
-// COM: Checks that DPAS encoding has been forwarded to the store op
-// COM: and the ttg.convert_layout operation has been removed
+// COM: Case 2: Similar to Case 1 but the loads do not have the blockIO "row_major" attribute.
+// COM: Checks that DPAS encoding has been forwarded from the dot op to the store op via the loop return values
+// COM: and that the ttg.convert_layout operation has been removed.
 // CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
     %c8_i32 = arith.constant 8 : i32
     %c64_i32 = arith.constant 64 : i32
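Case 2 (this hunk and the next) checks forwarding "via the loop return values": the dot's DPAS layout reaches the epilogue through the scf.for results, so the post-loop truncf and store operate on DPAS-encoded tensors with no conversion in between. A condensed sketch of the expected post-pass IR, with illustrative names and elided operands:

%res:3 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init, ...) -> (tensor<64x256xf32, #dpas>, ...) : i32 {
  %d = tt.dot %a, %b, %acc, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
  scf.yield %d, ... : tensor<64x256xf32, #dpas>, ...
}
// The loop result keeps the DPAS encoding, so no ttg.convert_layout is needed:
%t = arith.truncf %res#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
tt.store %p, %t {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>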
@@ -128,21 +128,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
       scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
     }
+    // CHECK: arith.truncf {{.*}} : tensor<64x256xf32, #[[DPAS]]> to tensor<64x256xf16, #[[DPAS]]>
+    // CHECK-NOT: ttg.convert_layout
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
+    // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
     %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
     %25 = ttg.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
     %26 = arith.extsi %arg8 : i32 to i64
-    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
     %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
-    // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
     tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
     tt.return
   }
 }
 
 // -----
 
-// COM: Case 3:
-// COM: Checks that DPAS encoding has been forwarded to the store op
+// COM: Case 3: Similar to Case 1 but with an additional store after the loop
+// COM: Checks that DPAS encoding has been forwarded from the dot op to the store op via the loop return values
 // COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
 // COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
 // CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -187,6 +189,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %21 = arith.extsi %arg7 : i32 to i64
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
+      // COM: Layout conversions in the loop should be removed.
+      // CHECK: scf.for
+      // CHECK-NOT: ttg.convert_layout
+      // CHECK: scf.yield
       %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
       %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
       %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
@@ -196,43 +202,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
       scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
     }
+    // CHECK: arith.truncf
+    // CHECK-NOT: ttg.convert_layout
+    // CHECK-DAG: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
+    // CHECK-DAG: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
+    // CHECK-NEXT: tt.store [[PTR1]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
+    // CHECK-NEXT: tt.load [[PTR2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[BLOCKED]]>>
     %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
     %25 = ttg.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
     %26 = arith.extsi %arg8 : i32 to i64
-    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
-    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
     %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
-    // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
     tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
     %35 = tt.load %27 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
+    // CHECK-NUM-2: ttg.convert_layout
     %36 = tt.make_tensor_ptr %arg13, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #blocked>>
     %37 = tt.load %36 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xf16, #blocked>>
     %38 = ttg.convert_layout %37 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #dot0>
     %39 = ttg.convert_layout %35 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #dot1>
     %40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<64x64xf16, #dot0> * tensor<64x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
+    // CHECK: tt.dot
+    // CHECK-NOT: ttg.convert_layout
+    // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf32, #[[DPAS]]>>
+    // CHECK: tt.store [[PTR3]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf32, #[[DPAS]]>>
     %41 = ttg.convert_layout %40 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
-    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf32, #[[DPAS]]>>
     %42 = tt.make_tensor_ptr %arg14, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf32, #blocked1>>
-    // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf32, #[[DPAS]]>>
     tt.store %42, %41 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf32, #blocked1>>
     tt.return
   }
 }
 
-
 // -----
 
-// COM: Case 4:
-// COM: Checks that DPAS encoding has been forwarded to the store op
-// COM: and the ttg.convert_layout operation in the loop has been removed
+// COM: Case 4: Similar to Case 1 but with a convert layout on the dot op return value op in the loop
+// COM: Checks that DPAS encoding has been forwarded from the dot op to the store op through the loop results
+// COM: and the ttg.convert_layout operations in the loop has been removed
 // CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
     %c1_i64 = arith.constant 1 : i64
     %c0_i32 = arith.constant 0 : i32
     %c0_i64 = arith.constant 0 : i64
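The CHECK-DAG/CHECK-NEXT block in Case 3 above encodes the multi-user constraint spelled out in its COM lines: because the original tensor pointer feeds both the store and a later load, the pass cannot simply retype it. The store gets a DPAS-encoded pointer while a pointer with the original blocked encoding is kept for the load. A condensed sketch of the expected output, with hypothetical value names and elided operands:

%p_dpas = tt.make_tensor_ptr %arg2, ... {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
%p_blocked = tt.make_tensor_ptr %arg2, ... {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
tt.store %p_dpas, %val {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>
%reload = tt.load %p_blocked {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>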
@@ -241,6 +252,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
     %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
+      // CHECK: scf.for
+      // CHECK-NOT: ttg.convert_layout
+      // CHECK: scf.yield
       %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
       %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
       %36 = ttg.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
@@ -249,14 +263,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
       %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
       %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
       %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
-      // CHECK-NOT: ttg.convert_layout
       %35 = ttg.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
       scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
     }
-    %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
+    // CHECK: arith.truncf
+    // CHECK-NOT: ttg.convert_layout
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
-    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
     // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
+    %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
+    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
     tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
     tt.return
   }
@@ -270,8 +285,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>) {
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
+  tt.func public @store_with_block_pointers(%arg0: !tt.ptr<f16>) {
     %c8_i32 = arith.constant 8 : i32
     %c64_i64 = arith.constant 64 : i64
     %c1_i64 = arith.constant 1 : i64
@@ -297,8 +312,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked3 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
-  tt.func public @test_4866(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i64) {
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @test_4866(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f32>, %arg2: i64) {
     %c1_i32 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked>
     %cst_0 = arith.constant dense<5.000000e-01> : tensor<16x32xf32, #blocked1>
@@ -311,11 +326,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %1 = tt.make_tensor_ptr %arg1, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf32, #blocked2>>
     %2:2 = scf.for %arg3 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>) : i32 {
       // CHECK: scf.for {{.*}}
-      // CHECK: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
-      // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
-      // CHECK: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
-      // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
-      // CHECK: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
+      // CHECK-NEXT: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
+      // CHECK-NEXT: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
+      // CHECK-NEXT: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
+      // CHECK-NEXT: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
+      // CHECK-NEXT: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
       %3 = tt.load %arg4 : !tt.ptr<tensor<16x32xf16, #blocked2>>
       %4 = ttg.convert_layout %3 : tensor<16x32xf16, #blocked2> -> tensor<16x32xf16, #blocked1>
       %5 = ttg.convert_layout %cst : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>>
