@@ -2518,3 +2518,49 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
25182518 tt.return
25192519 }
25202520}
2521+
2522+ // -----
2523+
2524+ // CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
2525+ // CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
2526+ #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
2527+ #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
2528+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
2529+   // CHECK-LABEL: matmul_kernel_reshape
2530+   tt.func public @matmul_kernel_reshape(%arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
2531+     %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
2532+     %c32_i32 = arith.constant 32 : i32
2533+     %c0_i32 = arith.constant 0 : i32
2534+     %c1_i32 = arith.constant 1 : i32
2535+     %c1_i64 = arith.constant 1 : i64
2536+     %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #mma>
2537+     %1 = arith.extsi %arg4 : i32 to i64
2538+     %2 = arith.extsi %arg3 : i32 to i64
2539+
2540+     // CHECK-DAG: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
2541+     // CHECK-DAG: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
2542+     // CHECK-DAG: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$BLOCKED]]>>
2543+
2544+     // CHECK-NOT: separator of consecutive DAGs
2545+     // CHECK-DAG: [[ADV_PTR2:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<64x64xf32, #[[$DPAS]]>>
2546+     // CHECK-DAG: [[ADV_PTR3:%.*]] = tt.advance [[PTR3]], {{.*}} : <tensor<64x64xf32, #[[$BLOCKED]]>>
2547+     %3 = tt.make_tensor_ptr %arg2, [%2, %1], [%1, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #blocked>>
2548+     %4 = tt.advance %3, [%c0_i32, %c32_i32] : !tt.ptr<tensor<64x64xf32, #blocked>>
2549+
2550+     // The following 2 stores write %cst, which is already #blocked, so they should keep using the blocked-layout pointers.
2551+     // CHECK-NOT: separator of consecutive DAGs
2552+     // CHECK-DAG: tt.store [[PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
2553+     // CHECK-DAG: tt.store [[ADV_PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
2554+     tt.store %3, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>
2555+     tt.store %4, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>
2556+
2557+     // The following 2 stores write %5, a layout conversion of the #mma constant; they should use the mma (DPAS) layout pointers, with no layout conversion left in the output.
2558+     // CHECK-NOT: ttg.convert_layout
2559+     // CHECK-DAG: tt.store [[PTR1]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
2560+     // CHECK-DAG: tt.store [[ADV_PTR2]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
2561+     %5 = ttg.convert_layout %cst_0 : tensor<64x64xf32, #mma> -> tensor<64x64xf32, #blocked>
2562+     tt.store %3, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
2563+     tt.store %4, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
2564+     tt.return
2565+   }
2566+ }
0 commit comments