@@ -2601,3 +2601,73 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
    tt.return
  }
}
+
+ // -----
+
+ // COM: Test that the DPAS layout is propagated to the store operation with tensor pointers.
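+ // COM: "Tensor pointer" here means a tensor of !tt.ptr values built with tt.splat/tt.addptr,
+ // COM: as opposed to a block pointer created by tt.make_tensor_ptr. The CHECK lines expect the
+ // COM: loop-carried accumulator, the trailing arith.truncf, and the tt.store to all end up in
+ // COM: the DPAS layout, with no ttg.convert_layout left in the epilogue.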
+ // CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+ #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+ #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+ #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+ #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+   // CHECK-LABEL: matmul_kernel_with_tensor_pointer
+   tt.func public @matmul_kernel_with_tensor_pointer(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32) {
+     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+     %c1_i32 = arith.constant 1 : i32
+     %c0_i32 = arith.constant 0 : i32
+     %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #blocked1>
+     %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x32xf16, #blocked2>
+     %cst_2 = arith.constant dense<32> : tensor<256x32xi32, #blocked2>
+     %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+     %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+     %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+     %19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
+     %23 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x32x!tt.ptr<f16>, #blocked2>
+     %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+     %28 = tt.expand_dims %26 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
+     %35 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
+     %40 = tt.splat %arg5 : i32 -> tensor<32x256xi32, #blocked1>
+     %41:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %23, %arg12 = %35) -> (tensor<256x256xf32, #blocked>, tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<32x256x!tt.ptr<f16>, #blocked1>) : i32 {
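+       // COM: Each iteration converts the #blocked accumulator to #mma for tt.dot and back again;
+       // COM: once the DPAS layout is propagated through the loop-carried value, these round-trip
+       // COM: conversions should fold away.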
+       %62 = tt.splat %arg9 : i32 -> tensor<1x32xi32, #blocked2>
+       %63 = arith.cmpi slt, %19, %62 : tensor<1x32xi32, #blocked2>
+       %64 = tt.broadcast %63 : tensor<1x32xi1, #blocked2> -> tensor<256x32xi1, #blocked2>
+       %65 = tt.load %arg11, %64, %cst_1 : tensor<256x32x!tt.ptr<f16>, #blocked2>
+       %66 = tt.splat %arg5 : i32 -> tensor<32x1xi32, #blocked1>
+       %67 = arith.cmpi slt, %28, %66 : tensor<32x1xi32, #blocked1>
+       %68 = tt.broadcast %67 : tensor<32x1xi1, #blocked1> -> tensor<32x256xi1, #blocked1>
+       %69 = tt.load %arg12, %68, %cst_0 : tensor<32x256x!tt.ptr<f16>, #blocked1>
+       %70 = ttg.convert_layout %arg10 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+       %71 = ttg.convert_layout %65 : tensor<256x32xf16, #blocked2> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+       %72 = ttg.convert_layout %69 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+       %73 = tt.dot %71, %72, %70, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+       %74 = ttg.convert_layout %73 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+       %75 = tt.addptr %arg11, %cst_2 : tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<256x32xi32, #blocked2>
+       %76 = tt.addptr %arg12, %40 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
+       scf.yield %74, %75, %76 : tensor<256x256xf32, #blocked>, tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<32x256x!tt.ptr<f16>, #blocked1>
+     }
+     // CHECK: [[SCF:%.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, {{.*}}) : i32 {
+     // CHECK: tt.expand_dims {{.*}} -> tensor<256x1xi32, #[[$DPAS]]>
+     // CHECK-NOT: ttg.convert_layout
+     // CHECK: [[RES:%.*]] = arith.truncf [[SCF]]#0 : tensor<256x256xf32, #[[$DPAS]]> to tensor<256x256xf16, #[[$DPAS]]>
+     // CHECK: tt.store {{.*}}, [[RES]], {{.*}} : tensor<256x256x!tt.ptr<f16>, #[[$DPAS]]>
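+     // COM: The epilogue below masks the store with the bounds check %57 and truncates the f32
+     // COM: accumulator to f16; the ttg.convert_layout feeding the store is what the CHECK lines
+     // COM: above expect the pass to eliminate.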
+     %42 = tt.expand_dims %3 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+     %45 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+     %46 = tt.addptr %45, %42 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+     %47 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
+     %48 = tt.broadcast %46 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x256x!tt.ptr<f16>, #blocked1>
+     %49 = tt.broadcast %47 : tensor<1x256xi32, #blocked1> -> tensor<256x256xi32, #blocked1>
+     %50 = tt.addptr %48, %49 : tensor<256x256x!tt.ptr<f16>, #blocked1>, tensor<256x256xi32, #blocked1>
+     %51 = tt.splat %arg3 : i32 -> tensor<256x1xi32, #blocked1>
+     %52 = arith.cmpi slt, %42, %51 : tensor<256x1xi32, #blocked1>
+     %53 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1>
+     %54 = arith.cmpi slt, %47, %53 : tensor<1x256xi32, #blocked1>
+     %55 = tt.broadcast %52 : tensor<256x1xi1, #blocked1> -> tensor<256x256xi1, #blocked1>
+     %56 = tt.broadcast %54 : tensor<1x256xi1, #blocked1> -> tensor<256x256xi1, #blocked1>
+     %57 = arith.andi %55, %56 : tensor<256x256xi1, #blocked1>
+     %58 = arith.truncf %41#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+     %59 = ttg.convert_layout %58 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked1>
+     tt.store %50, %59, %57 : tensor<256x256x!tt.ptr<f16>, #blocked1>
+     tt.return
+   }
+ }