@@ -118,9 +118,10 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128], [64, 128])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
-@pytest.mark.parametrize("block_ptr", [True, False])
+@pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
+                                                             (False, True)])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -131,41 +132,50 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
-    if block_ptr:
+    if load_block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
+    else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
+    if store_block_ptr:
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
             %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
             tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
 
     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
-            %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
+
+            %stride_N = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
-            %3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
+            %row_stride = arith.muli %2, %stride_N : tensor<{M}x1xi32, #layout>
             %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
             %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
-            %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+            %6 = tt.broadcast %row_stride : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
-            %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+            {load_ops}
             {store_ops}
 
             tt.return
@@ -181,7 +191,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     x = torch.empty_like(a)
 
-    temp_file = tmp_path / "test_block_store.ttgir"
+    temp_file = tmp_path / "test_block_io.ttgir"
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 