@@ -2518,3 +2518,49 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
2518
2518
tt.return
2519
2519
}
2520
2520
}
2521
+
2522
// -----

// Test case: one block pointer with #blocked encoding (%3) and its advanced
// copy (%4) are each stored through twice -- first with a value that is
// natively #blocked (%cst), then with a value produced by converting an #mma
// tensor to #blocked (%5).
//
// Expected output, as pinned by the directives below:
//   * three tt.make_tensor_ptr ops: two typed with the DPAS encoding and one
//     typed with the blocked encoding;
//   * tt.advance ops derived from the second DPAS pointer and from the
//     blocked pointer;
//   * the first two stores still go through the blocked pointers;
//   * no ttg.convert_layout appears before the last two stores, which go
//     through the DPAS-typed pointers instead.
// NOTE(review): the RUN line is outside this excerpt; presumably this is
// exercised by the layout-conversion removal pass -- confirm at top of file.

// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
  // CHECK-LABEL: matmul_kernel_reshape
  tt.func public @matmul_kernel_reshape(%arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
    // %cst is the value that is already in the blocked layout; %cst_0 is the
    // value that starts out in the #mma (DPAS) layout.
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #mma>
    // Shape/stride operands of tt.make_tensor_ptr are i64, so widen the i32
    // function arguments first.
    %1 = arith.extsi %arg4 : i32 to i64
    %2 = arith.extsi %arg3 : i32 to i64

    // CHECK-DAG: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
    // CHECK-DAG: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
    // CHECK-DAG: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$BLOCKED]]>>

    // The NOT directive below is only a barrier between two DAG groups: it
    // forces the tt.advance matches to come after the make_tensor_ptr matches.
    // CHECK-NOT: separator of consecutive DAGs
    // CHECK-DAG: [[ADV_PTR2:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<64x64xf32, #[[$DPAS]]>>
    // CHECK-DAG: [[ADV_PTR3:%.*]] = tt.advance [[PTR3]], {{.*}} : <tensor<64x64xf32, #[[$BLOCKED]]>>
    %3 = tt.make_tensor_ptr %arg2, [%2, %1], [%1, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #blocked>>
    %4 = tt.advance %3, [%c0_i32, %c32_i32] : !tt.ptr<tensor<64x64xf32, #blocked>>

    // The following 2 stores should use blocked layout.
    // CHECK-NOT: separator of consecutive DAGs
    // CHECK-DAG: tt.store [[PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
    // CHECK-DAG: tt.store [[ADV_PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
    tt.store %3, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>
    tt.store %4, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>

    // The following 2 stores should use mma layout.
    // CHECK-NOT: ttg.convert_layout
    // CHECK-DAG: tt.store [[PTR1]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
    // CHECK-DAG: tt.store [[ADV_PTR2]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
    %5 = ttg.convert_layout %cst_0 : tensor<64x64xf32, #mma> -> tensor<64x64xf32, #blocked>
    tt.store %3, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
    tt.store %4, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
    tt.return
  }
}
0 commit comments