@@ -286,3 +286,47 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
286
286
tt.return
287
287
}
288
288
}
289
+
290
+ // -----
291
+
292
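+ // COM: Regression test for issue #4866: the chained layout conversions around tt.dot inside the scf.for loop are expected to fold to a single ttg.convert_layout on each side of the dot.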
293
+
294
+ // CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
295
+ // CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
296
+ // COM: Layout roles in the input IR below: #blocked and #blocked1 are the layouts of the two
+ // COM: constants, #blocked2 is the layout of the block pointers used by tt.load / tt.store, and
+ // COM: #blocked3 is the parent layout of the tt.dot operands and result.
+ #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
297
+ #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
298
+ #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
299
+ #blocked3 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
300
+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
301
+   tt.func public @test_4866(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i64) {
302
+     %c1_i32 = arith.constant 1 : i32
303
+     %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked>
304
+     %cst_0 = arith.constant dense<5.000000e-01> : tensor<16x32xf32, #blocked1>
305
+     %c64_i64 = arith.constant 64 : i64
306
+     %c32_i32 = arith.constant 32 : i32
307
+     %c0_i32 = arith.constant 0 : i32
308
+     %c1_i64 = arith.constant 1 : i64
309
+     %c16_i32 = arith.constant 16 : i32
310
+     %0 = tt.make_tensor_ptr %arg0, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16, #blocked2>>
311
+     %1 = tt.make_tensor_ptr %arg1, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf32, #blocked2>>
312
+     %2:2 = scf.for %arg3 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>) : i32 {
313
+       // CHECK: scf.for {{.*}}
314
+       // CHECK: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
315
+       // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
316
+       // COM: The two constant operands of the dot are matched with wildcards instead of printed
+       // COM: SSA names; the operand types on the same line already identify each of them.
+       // CHECK: [[DOT_RES:%.*]] = tt.dot %{{.*}}, [[CONV1]], %{{.*}} : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
317
+       // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
318
+       // CHECK: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
319
+       // COM: Input: the loaded tile reaches the dot through two chained conversions
+       // COM: (#blocked2 -> #blocked1 -> dot_op), and the dot result reaches the store through two
+       // COM: more (#blocked3 -> #blocked1 -> #blocked2).
+       %3 = tt.load %arg4 : !tt.ptr<tensor<16x32xf16, #blocked2>>
320
+       %4 = ttg.convert_layout %3 : tensor<16x32xf16, #blocked2> -> tensor<16x32xf16, #blocked1>
321
+       %5 = ttg.convert_layout %cst : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>>
322
+       %6 = ttg.convert_layout %4 : tensor<16x32xf16, #blocked1> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>>
323
+       %7 = ttg.convert_layout %cst_0 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked3>
324
+       %8 = tt.dot %5, %6, %7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<16x32xf32, #blocked3>
325
+       %9 = ttg.convert_layout %8 : tensor<16x32xf32, #blocked3> -> tensor<16x32xf32, #blocked1>
326
+       %10 = ttg.convert_layout %9 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked2>
327
+       tt.store %arg5, %10 : !tt.ptr<tensor<16x32xf32, #blocked2>>
328
+       scf.yield %arg4, %arg5 : !tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>
329
+     }
330
+     tt.return
331
+   }
332
+ }
0 commit comments