@@ -2472,3 +2472,49 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
     tt.return
   }
 }
+
+// -----
+
+// COM: Test that the DPAS layout is propagated to the store operation in the presence of an advance operation updating its base pointer.
+// CHECK-NOT: #ttg.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
+#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
+  // CHECK-LABEL: matmul_kernel_with_block_pointers
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1>
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
+    %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
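+    // COM: The loop loads the A and B tiles through block pointers, converts them to the dot-operand layouts of #dpas, and advances both pointers each iteration.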
+    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
+      // CHECK-NOT: ttg.convert_layout
+      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
+      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+      %36 = ttg.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
+      %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
+      %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
+      %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
+      %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
+      %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
+      %35 = ttg.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
+      scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
+    }
+    %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
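+    // COM: The epilogue re-creates a block pointer, advances it, and stores the result; the rewritten make_tensor_ptr/advance/store chain must carry the DPAS layout end to end.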
+    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[$DPAS]]>>
+    // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<64x256xf16, #[[$DPAS]]>>
+    // CHECK: tt.store [[PTR2]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[$DPAS]]>>
+    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
+    %newptr = tt.advance %27, [%c32_i32, %c32_i32] : <tensor<64x256xf16, #blocked1>>
+    tt.store %newptr, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
+    tt.return
+  }
+}