@@ -43,7 +43,7 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
-@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
 def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
@@ -62,43 +62,43 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
     #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}>
     #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}>
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
-        tt.func public @tensor_pointer_block_load(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
+        tt.func public @tensor_pointer_block_store(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
            // A matrix
            %stride_a = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_a>
            %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>>
            %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a>
-           %4 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
-           %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
-           %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
-           %7 = tt.broadcast %4 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-           %8 = tt.broadcast %6 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
-           %9 = arith.addi %7, %8 : tensor<{M}x{N}xi32, #dot_a>
-
-           %10 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-           %11 = tt.addptr %10, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-           %12 = tt.load %11 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-           %13 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
-           %14 = tt.addptr %13, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
-           tt.store %14, %12 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+           %3 = arith.muli %2, %stride_a : tensor<{M}x1xi32, #dot_a>
+           %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>>
+           %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a>
+           %6 = tt.broadcast %3 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
+           %7 = tt.broadcast %5 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a>
+           %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #dot_a>
+
+           %9 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+           %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
+           %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+           %12 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
+           %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a>
+           tt.store %13, %11 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>
 
            // B matrix
            %stride_b = arith.constant dense<{N}> : tensor<{M}x1xi32, #dot_b>
-           %22 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
-           %44 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
-           %46 = tt.expand_dims %44 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
-           %49 = arith.muli %46, %stride_b : tensor<{M}x1xi32, #dot_b>
-           %50 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
-           %51 = tt.broadcast %49 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-           %52 = tt.broadcast %50 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
-           %53 = arith.addi %51, %52 : tensor<{M}x{N}xi32, #dot_b>
-
-           %54 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-           %55 = tt.addptr %54, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-           %56 = tt.load %55 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-           %57 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
-           %58 = tt.addptr %57, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
-           tt.store %58, %56 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+           %21 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>>
+           %22 = tt.expand_dims %21 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b>
+           %23 = arith.muli %22, %stride_b : tensor<{M}x1xi32, #dot_b>
+           %24 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>>
+           %25 = tt.expand_dims %24 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b>
+           %26 = tt.broadcast %23 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
+           %27 = tt.broadcast %25 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b>
+           %28 = arith.addi %26, %27 : tensor<{M}x{N}xi32, #dot_b>
+
+           %29 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+           %30 = tt.addptr %29, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
+           %31 = tt.load %30 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+           %32 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
+           %33 = tt.addptr %32, %28 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b>
+           tt.store %33, %31 {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>
 
            tt.return
        }}
@@ -120,5 +120,3 @@ def test_tensor_pointer_block_store(M, N, dtype_str, layout, device, tmp_path: pathlib.Path):
 
     kernel[(1, 1, 1)](a, x, a, y)
     assert torch.equal(a, x) and torch.equal(a, y)
-
-    temp_file.unlink()
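
For context: the IR above computes each element's offset as `m * N + n` (row-major) via `make_range`/`expand_dims`/`broadcast`, so the kernel is a pure copy from `%arg0`/`%arg2` into `%arg1`/`%arg3` through `tt.store` ops tagged `ttig.block_io = "row_major"`. Below is a minimal sketch of how such a TTGIR template is typically driven; the names `ir`, `temp_file`, and `run_block_store_copy`, and the use of `triton.compile` on a `.ttgir` file, follow the usual pattern of these block-I/O tests but are assumptions, not part of this diff.

```python
import pathlib

import torch
import triton


def run_block_store_copy(ir: str, M: int, N: int, device, tmp_path: pathlib.Path):
    # Materialize the TTGIR template and compile it directly (assumed pattern:
    # triton.compile accepts a path to a .ttgir file in these backend tests).
    temp_file = tmp_path / "test_tensor_pointer_block_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    # One source tensor and two destinations, matching kernel[(1, 1, 1)](a, x, a, y):
    # the A path copies a -> x through the #dot_a layout, the B path copies
    # a -> y through the #dot_b layout.
    a = torch.randn((M, N), dtype=torch.float32, device=device)
    x = torch.empty_like(a)
    y = torch.empty_like(a)
    kernel[(1, 1, 1)](a, x, a, y)

    # The kernel is a pure copy, so both outputs must equal the input exactly.
    assert torch.equal(a, x) and torch.equal(a, y)
```

With pytest's `tmp_path` fixture, the temporary `.ttgir` file is cleaned up automatically, which is presumably why the explicit `temp_file.unlink()` is dropped in this change.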