@@ -20,22 +20,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
20
20
%c64_i32 = arith.constant 64 : i32
21
21
%c5120_i32 = arith.constant 5120 : i32
22
22
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x256 xf32 , #blocked >
23
- %0 = tt.get_program_id x : i32
24
- %1 = arith.divsi %0 , %c64_i32 : i32
25
- %2 = arith.muli %1 , %c4_i32 : i32
26
- %3 = arith.subi %c4_i32 , %2 : i32
27
- %4 = arith.minsi %3 , %c4_i32 : i32
28
- %5 = arith.remsi %0 , %4 : i32
29
- %6 = arith.addi %2 , %5 : i32
30
- %7 = arith.remsi %0 , %c64_i32 : i32
31
- %8 = arith.divsi %7 , %4 : i32
32
- %9 = arith.muli %6 , %c256_i32 : i32
23
+
33
24
// CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]>>
34
- %10 = tt.make_tensor_ptr %arg0 , [%c1024_i64 , %c5120_i64 ], [%c5120_i64 , %c1_i64 ], [%9 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xf16 , #blocked1 >>
35
- %11 = arith.muli %8 , %c256_i32 : i32
25
+ %10 = tt.make_tensor_ptr %arg0 , [%c1024_i64 , %c5120_i64 ], [%c5120_i64 , %c1_i64 ], [%c256_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x32 xf16 , #blocked1 >>
36
26
// CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]>>
37
- %12 = tt.make_tensor_ptr %arg1 , [%c5120_i64 , %c4096_i64 ], [%c4096_i64 , %c1_i64 ], [%c0_i32 , %11 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x256 xf16 , #blocked2 >>
38
- // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]])
27
+ %12 = tt.make_tensor_ptr %arg1 , [%c5120_i64 , %c4096_i64 ], [%c4096_i64 , %c1_i64 ], [%c0_i32 , %c256_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x256 xf16 , #blocked2 >>
28
+ // CHECK: %[[RES:.*]]:3 = scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]])
39
29
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args (%arg4 = %cst , %arg5 = %10 , %arg6 = %12 ) -> (tensor <256 x256 xf32 , #blocked >, !tt.ptr <tensor <256 x32 xf16 , #blocked1 >>, !tt.ptr <tensor <32 x256 xf16 , #blocked2 >>) : i32 {
40
30
%17 = tt.load %arg5 {boundaryCheck = array<i32 : 0 , 1 >, ttig.block_io = " row_major" } : !tt.ptr <tensor <256 x32 xf16 , #blocked1 >>
41
31
// CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]>>
@@ -58,7 +48,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
58
48
// CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]]
59
49
scf.yield %25 , %26 , %27 : tensor <256 x256 xf32 , #blocked >, !tt.ptr <tensor <256 x32 xf16 , #blocked1 >>, !tt.ptr <tensor <32 x256 xf16 , #blocked2 >>
60
50
}
61
- %14 = tt.make_tensor_ptr %arg2 , [%c1024_i64 , %c4096_i64 ], [%c4096_i64 , %c1_i64 ], [%9 , %11 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x256 xf16 , #blocked2 >>
51
+ %14 = tt.make_tensor_ptr %arg2 , [%c1024_i64 , %c4096_i64 ], [%c4096_i64 , %c1_i64 ], [%c0_i32 , %c256_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <256 x256 xf16 , #blocked2 >>
52
+ // CHECK: arith.truncf %[[RES]]#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
62
53
%15 = arith.truncf %13#0 : tensor <256 x256 xf32 , #blocked > to tensor <256 x256 xf16 , #blocked >
63
54
%16 = ttg.convert_layout %15 : tensor <256 x256 xf16 , #blocked > -> tensor <256 x256 xf16 , #blocked2 >
64
55
tt.store %14 , %16 {boundaryCheck = array<i32 : 0 , 1 >} : !tt.ptr <tensor <256 x256 xf16 , #blocked2 >>
0 commit comments