@@ -75,16 +75,16 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32,
75
75
76
76
// -----
77
77
78
- // COM: Case 2:
79
- // COM: Checks that DPAS encoding has been forwarded to the store op
80
- // COM: and the ttg.convert_layout operation has been removed
78
+ // COM: Case 2: Similar to Case 1 but the loads do not have the blockIO "row_major" attribute.
79
+ // COM: Checks that DPAS encoding has been forwarded from the dot op to the store op via the loop return values
80
+ // COM: and that the ttg.convert_layout operation has been removed.
81
81
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
82
82
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
83
83
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
84
84
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [1 , 4 ], repCluster = [1 , 1 ], A = [8 , 16 ], B = [16 , 16 ], C = [8 , 16 ]}>
85
85
#dot0 = #ttg.dot_op <{opIdx = 0 , parent = #dpas , kWidth =1 }>
86
86
#dot1 = #ttg.dot_op <{opIdx = 1 , parent = #dpas , kWidth =2 }>
87
- module attributes {" ttg.num-ctas " = 1 : i32 , " ttg.num- warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.support_sg_2d_block" } {
87
+ module attributes {" ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.support_sg_2d_block" } {
88
88
tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 >, %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >, %arg3: i32 , %arg4: i32 , %arg5: i32 , %arg6: i32 , %arg7: i32 , %arg8: i32 ) {
89
89
%c8_i32 = arith.constant 8 : i32
90
90
%c64_i32 = arith.constant 64 : i32
@@ -128,21 +128,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
128
128
%34 = tt.advance %arg12 , [%c32_i32 , %c0_i32 ] : <tensor <32 x256 xf16 , #blocked1 >>
129
129
scf.yield %32 , %33 , %34 : tensor <64 x256 xf32 , #dpas >, !tt.ptr <tensor <64 x32 xf16 , #blocked >>, !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>
130
130
}
131
+ // CHECK: arith.truncf {{.*}} : tensor<64x256xf32, #[[DPAS]]> to tensor<64x256xf16, #[[DPAS]]>
132
+ // CHECK-NOT: ttg.convert_layout
133
+ // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
134
+ // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
131
135
%24 = arith.truncf %23#0 : tensor <64 x256 xf32 , #dpas > to tensor <64 x256 xf16 , #dpas >
132
136
%25 = ttg.convert_layout %24 : tensor <64 x256 xf16 , #dpas > -> tensor <64 x256 xf16 , #blocked1 >
133
137
%26 = arith.extsi %arg8 : i32 to i64
134
- // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
135
138
%27 = tt.make_tensor_ptr %arg2 , [%15 , %20 ], [%26 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x256 xf16 , #blocked1 >>
136
- // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
137
139
tt.store %27 , %25 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x256 xf16 , #blocked1 >>
138
140
tt.return
139
141
}
140
142
}
141
143
142
144
// -----
143
145
144
- // COM: Case 3:
145
- // COM: Checks that DPAS encoding has been forwarded to the store op
146
+ // COM: Case 3: Similar to Case 1 but with an additional store after the loop
147
+ // COM: Checks that DPAS encoding has been forwarded from the dot op to the store op via the loop return values
146
148
// COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
147
149
// COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
148
150
// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -187,6 +189,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
187
189
%21 = arith.extsi %arg7 : i32 to i64
188
190
%22 = tt.make_tensor_ptr %arg1 , [%16 , %20 ], [%21 , %c1_i64 ], [%c0_i32 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x256 xf16 , #blocked1 >>
189
191
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args (%arg10 = %cst , %arg11 = %18 , %arg12 = %22 ) -> (tensor <64 x256 xf32 , #dpas >, !tt.ptr <tensor <64 x32 xf16 , #blocked >>, !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>) : i32 {
192
+ // COM: Layout conversions in the loop should be removed.
193
+ // CHECK: scf.for
194
+ // CHECK-NOT: ttg.convert_layout
195
+ // CHECK: scf.yield
190
196
%28 = tt.load %arg11 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <64 x32 xf16 , #blocked >>
191
197
%29 = tt.load %arg12 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>
192
198
%30 = ttg.convert_layout %28 : tensor <64 x32 xf16 , #blocked > -> tensor <64 x32 xf16 , #dot0 >
@@ -196,43 +202,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
196
202
%34 = tt.advance %arg12 , [%c32_i32 , %c0_i32 ] : <tensor <32 x256 xf16 , #blocked1 >>
197
203
scf.yield %32 , %33 , %34 : tensor <64 x256 xf32 , #dpas >, !tt.ptr <tensor <64 x32 xf16 , #blocked >>, !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>
198
204
}
205
+ // CHECK: arith.truncf
206
+ // CHECK-NOT: ttg.convert_layout
207
+ // CHECK-DAG: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
208
+ // CHECK-DAG: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
209
+ // CHECK-NEXT: tt.store [[PTR1]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
210
+ // CHECK-NEXT: tt.load [[PTR2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[BLOCKED]]>>
199
211
%24 = arith.truncf %23#0 : tensor <64 x256 xf32 , #dpas > to tensor <64 x256 xf16 , #dpas >
200
212
%25 = ttg.convert_layout %24 : tensor <64 x256 xf16 , #dpas > -> tensor <64 x256 xf16 , #blocked1 >
201
213
%26 = arith.extsi %arg8 : i32 to i64
202
- // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
203
- // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
204
214
%27 = tt.make_tensor_ptr %arg2 , [%15 , %20 ], [%26 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x256 xf16 , #blocked1 >>
205
- // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
206
215
tt.store %27 , %25 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x256 xf16 , #blocked1 >>
207
216
%35 = tt.load %27 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x256 xf16 , #blocked1 >>
217
+ // CHECK-COUNT-2: ttg.convert_layout
208
218
%36 = tt.make_tensor_ptr %arg13 , [%15 , %16 ], [%17 , %c1_i64 ], [%14 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x64 xf16 , #blocked >>
209
219
%37 = tt.load %36 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x64 xf16 , #blocked >>
210
220
%38 = ttg.convert_layout %37 : tensor <64 x64 xf16 , #blocked > -> tensor <64 x64 xf16 , #dot0 >
211
221
%39 = ttg.convert_layout %35 : tensor <64 x256 xf16 , #blocked1 > -> tensor <64 x256 xf16 , #dot1 >
212
222
%40 = tt.dot %38 , %39 , %cst , inputPrecision = tf32 : tensor <64 x64 xf16 , #dot0 > * tensor <64 x256 xf16 , #dot1 > -> tensor <64 x256 xf32 , #dpas >
223
+ // CHECK: tt.dot
224
+ // CHECK-NOT: ttg.convert_layout
225
+ // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf32, #[[DPAS]]>>
226
+ // CHECK: tt.store [[PTR3]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf32, #[[DPAS]]>>
213
227
%41 = ttg.convert_layout %40 : tensor <64 x256 xf32 , #dpas > -> tensor <64 x256 xf32 , #blocked1 >
214
- // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf32, #[[DPAS]]>>
215
228
%42 = tt.make_tensor_ptr %arg14 , [%15 , %20 ], [%26 , %c1_i64 ], [%14 , %19 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x256 xf32 , #blocked1 >>
216
- // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf32, #[[DPAS]]>>
217
229
tt.store %42 , %41 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x256 xf32 , #blocked1 >>
218
230
tt.return
219
231
}
220
232
}
221
233
222
-
223
234
// -----
224
235
225
- // COM: Case 4:
226
- // COM: Checks that DPAS encoding has been forwarded to the store op
227
- // COM: and the ttg.convert_layout operation in the loop has been removed
236
+ // COM: Case 4: Similar to Case 1 but with a layout conversion applied to the dot op result inside the loop.
237
+ // COM: Checks that DPAS encoding has been forwarded from the dot op to the store op through the loop results
238
+ // COM: and that the ttg.convert_layout operations in the loop have been removed.
228
239
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
229
240
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
230
241
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
231
242
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [1 , 4 ], repCluster = [1 , 1 ], A = [8 , 16 ], B = [16 , 16 ], C = [8 , 16 ]}>
232
243
#dot0 = #ttg.dot_op <{opIdx = 0 , parent = #dpas , kWidth =1 }>
233
244
#dot1 = #ttg.dot_op <{opIdx = 1 , parent = #dpas , kWidth =2 }>
234
245
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.support_sg_2d_block" } {
235
- tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 >, %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >, %arg3: i32 , %arg4: i32 , % arg5: i32 , %arg6: i32 , %arg7: i32 , %arg8 : i32 ) {
246
+ tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 >, %arg1: !tt.ptr <f16 >, %arg2: !tt.ptr <f16 >, %arg5: i32 ) {
236
247
%c1_i64 = arith.constant 1 : i64
237
248
%c0_i32 = arith.constant 0 : i32
238
249
%c0_i64 = arith.constant 0 : i64
@@ -241,6 +252,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
241
252
%18 = tt.make_tensor_ptr %arg0 , [%c0_i64 , %c0_i64 ], [%c0_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x32 xf16 , #blocked >>
242
253
%22 = tt.make_tensor_ptr %arg1 , [%c0_i64 , %c0_i64 ], [%c0_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x256 xf16 , #blocked1 >>
243
254
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args (%arg10 = %cst , %arg11 = %18 , %arg12 = %22 ) -> (tensor <64 x256 xf32 , #blocked1 >, !tt.ptr <tensor <64 x32 xf16 , #blocked >>, !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>) : i32 {
255
+ // CHECK: scf.for
256
+ // CHECK-NOT: ttg.convert_layout
257
+ // CHECK: scf.yield
244
258
%28 = tt.load %arg11 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <64 x32 xf16 , #blocked >>
245
259
%29 = tt.load %arg12 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>
246
260
%36 = ttg.convert_layout %arg10 : tensor <64 x256 xf32 , #blocked1 > -> tensor <64 x256 xf32 , #dpas >
@@ -249,14 +263,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
249
263
%32 = tt.dot %30 , %31 , %36 , inputPrecision = tf32 : tensor <64 x32 xf16 , #dot0 > * tensor <32 x256 xf16 , #dot1 > -> tensor <64 x256 xf32 , #dpas >
250
264
%33 = tt.advance %arg11 , [%c0_i32 , %c32_i32 ] : <tensor <64 x32 xf16 , #blocked >>
251
265
%34 = tt.advance %arg12 , [%c32_i32 , %c0_i32 ] : <tensor <32 x256 xf16 , #blocked1 >>
252
- // CHECK-NOT: ttg.convert_layout
253
266
%35 = ttg.convert_layout %32 : tensor <64 x256 xf32 , #dpas > -> tensor <64 x256 xf32 , #blocked1 >
254
267
scf.yield %35 , %33 , %34 : tensor <64 x256 xf32 , #blocked1 >, !tt.ptr <tensor <64 x32 xf16 , #blocked >>, !tt.ptr <tensor <32 x256 xf16 , #blocked1 >>
255
268
}
256
- %24 = arith.truncf %23#0 : tensor <64 x256 xf32 , #blocked1 > to tensor <64 x256 xf16 , #blocked1 >
269
+ // CHECK: arith.truncf
270
+ // CHECK-NOT: ttg.convert_layout
257
271
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
258
- %27 = tt.make_tensor_ptr %arg2 , [%c0_i64 , %c0_i64 ], [%c0_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x256 xf16 , #blocked1 >>
259
272
// CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
273
+ %24 = arith.truncf %23#0 : tensor <64 x256 xf32 , #blocked1 > to tensor <64 x256 xf16 , #blocked1 >
274
+ %27 = tt.make_tensor_ptr %arg2 , [%c0_i64 , %c0_i64 ], [%c0_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <64 x256 xf16 , #blocked1 >>
260
275
tt.store %27 , %24 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <64 x256 xf16 , #blocked1 >>
261
276
tt.return
262
277
}
@@ -270,8 +285,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
270
285
// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
271
286
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
272
287
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
273
- module attributes {" ttg.num-ctas " = 1 : i32 , " ttg.num- warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.support_sg_2d_block" } {
274
- tt.func public @matmul_kernel_with_block_pointers (%arg0: !tt.ptr <f16 >) {
288
+ module attributes {" ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.support_sg_2d_block" } {
289
+ tt.func public @store_with_block_pointers (%arg0: !tt.ptr <f16 >) {
275
290
%c8_i32 = arith.constant 8 : i32
276
291
%c64_i64 = arith.constant 64 : i64
277
292
%c1_i64 = arith.constant 1 : i64
@@ -297,8 +312,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
297
312
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
298
313
#blocked2 = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [4 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
299
314
#blocked3 = #ttg.blocked <{sizePerThread = [2 , 2 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
300
- module attributes {" ttg.num-ctas " = 1 : i32 , " ttg.num- warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 , ttig.support_dpas , ttig.support_sg_2d_block } {
301
- tt.func public @test_4866 (%arg0: !tt.ptr <f16 > { tt.divisibility = 16 : i32 } , %arg1: !tt.ptr <f32 > { tt.divisibility = 16 : i32 } , %arg2: i64 ) {
315
+ module attributes {" ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
316
+ tt.func public @test_4866 (%arg0: !tt.ptr <f16 >, %arg1: !tt.ptr <f32 >, %arg2: i64 ) {
302
317
%c1_i32 = arith.constant 1 : i32
303
318
%cst = arith.constant dense <0.000000e+00 > : tensor <16 x16 xf16 , #blocked >
304
319
%cst_0 = arith.constant dense <5.000000e-01 > : tensor <16 x32 xf32 , #blocked1 >
@@ -311,11 +326,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
311
326
%1 = tt.make_tensor_ptr %arg1 , [%arg2 , %c64_i64 ], [%c64_i64 , %c1_i64 ], [%c0_i32 , %c32_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <16 x32 xf32 , #blocked2 >>
312
327
%2:2 = scf.for %arg3 = %c0_i32 to %c16_i32 step %c1_i32 iter_args (%arg4 = %0 , %arg5 = %1 ) -> (!tt.ptr <tensor <16 x32 xf16 , #blocked2 >>, !tt.ptr <tensor <16 x32 xf32 , #blocked2 >>) : i32 {
313
328
// CHECK: scf.for {{.*}}
314
- // CHECK: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
315
- // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
316
- // CHECK: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
317
- // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
318
- // CHECK: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
329
+ // CHECK-NEXT: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
330
+ // CHECK-NEXT: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
331
+ // CHECK-NEXT: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
332
+ // CHECK-NEXT: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
333
+ // CHECK-NEXT: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
319
334
%3 = tt.load %arg4 : !tt.ptr <tensor <16 x32 xf16 , #blocked2 >>
320
335
%4 = ttg.convert_layout %3 : tensor <16 x32 xf16 , #blocked2 > -> tensor <16 x32 xf16 , #blocked1 >
321
336
%5 = ttg.convert_layout %cst : tensor <16 x16 xf16 , #blocked > -> tensor <16 x16 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked3 }>>
0 commit comments