@@ -286,3 +286,47 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
286286 tt.return
287287 }
288288}
289+
// -----

// COM: Fix for issue #4866
// COM: Regression test: inside an scf.for loop, a block-pointer load in #blocked2 is
// COM: routed through #blocked1 into a tt.dot whose operands and result use #blocked3,
// COM: and the dot result is routed back through #blocked1 to #blocked2 for the store.
// COM: The directives below match one convert_layout feeding the dot's B operand and
// COM: one convert_layout between the dot result and the store, with the load and store
// COM: pointers typed in the [1, 4] x [4, 8] layout and the dot in the [2, 2] x [2, 16] layout.

// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
  tt.func public @test_4866(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i64) {
    %c1_i32 = arith.constant 1 : i32
    // COM: %cst is the dot's A operand (16x16 f16) and %cst_0 is its accumulator (16x32 f32).
    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked>
    %cst_0 = arith.constant dense<5.000000e-01> : tensor<16x32xf32, #blocked1>
    %c64_i64 = arith.constant 64 : i64
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c16_i32 = arith.constant 16 : i32
    // COM: Both block pointers use shape [%arg2, 64], strides [64, 1] and offsets [0, 32].
    %0 = tt.make_tensor_ptr %arg0, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16, #blocked2>>
    %1 = tt.make_tensor_ptr %arg1, [%arg2, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c32_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf32, #blocked2>>
    // COM: The loop carries both pointers unchanged (they are yielded as-is each iteration).
    %2:2 = scf.for %arg3 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>) : i32 {
      // CHECK: scf.for {{.*}}
      // CHECK: [[LOAD_RES:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<16x32xf16, #[[BLOCKED1]]>>
      // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD_RES]] : tensor<16x32xf16, #[[BLOCKED1]]> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>>
      // CHECK: [[DOT_RES:%.*]] = tt.dot %cst_0, [[CONV1]], %cst : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<16x32xf32, #[[BLOCKED]]>
      // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[DOT_RES]] : tensor<16x32xf32, #[[BLOCKED]]> -> tensor<16x32xf32, #[[BLOCKED1]]>
      // CHECK: tt.store {{.*}}, [[CONV2]] : !tt.ptr<tensor<16x32xf32, #[[BLOCKED1]]>>
      %3 = tt.load %arg4 : !tt.ptr<tensor<16x32xf16, #blocked2>>
      // COM: Input side: #blocked2 -> #blocked1 -> dot_op<opIdx = 1, parent = #blocked3>.
      %4 = ttg.convert_layout %3 : tensor<16x32xf16, #blocked2> -> tensor<16x32xf16, #blocked1>
      %5 = ttg.convert_layout %cst : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>>
      %6 = ttg.convert_layout %4 : tensor<16x32xf16, #blocked1> -> tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>>
      %7 = ttg.convert_layout %cst_0 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked3>
      %8 = tt.dot %5, %6, %7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<16x32xf32, #blocked3>
      // COM: Output side: #blocked3 -> #blocked1 -> #blocked2 before the store.
      %9 = ttg.convert_layout %8 : tensor<16x32xf32, #blocked3> -> tensor<16x32xf32, #blocked1>
      %10 = ttg.convert_layout %9 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #blocked2>
      tt.store %arg5, %10 : !tt.ptr<tensor<16x32xf32, #blocked2>>
      scf.yield %arg4, %arg5 : !tt.ptr<tensor<16x32xf16, #blocked2>>, !tt.ptr<tensor<16x32xf32, #blocked2>>
    }
    tt.return
  }
}
0 commit comments