@@ -118,9 +118,10 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128], [64, 128])])
 @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
 @pytest.mark.parametrize("layout", layouts)
-@pytest.mark.parametrize("block_ptr", [True, False])
+@pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
+                                                             (False, True)])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -131,41 +132,50 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
-    if block_ptr:
+    if load_block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
+    else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
+    if store_block_ptr:
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
             %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
             tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
 
     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
     tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
-        %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
+        %M_i64 = arith.constant {M} : i64
+        %N_i64 = arith.constant {N} : i64
+        %c1_i64 = arith.constant 1 : i64
+        %c0_i32 = arith.constant 0 : i32
+
+        %stride_N = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
         %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
         %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
-        %3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
+        %row_stride = arith.muli %2, %stride_N : tensor<{M}x1xi32, #layout>
         %4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
         %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
-        %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %6 = tt.broadcast %row_stride : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
         %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
-        %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-        %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-        %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-        %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+        {load_ops}
         {store_ops}
 
         tt.return
@@ -181,7 +191,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     x = torch.empty_like(a)
 
-    temp_file = tmp_path / "test_block_store.ttgir"
+    temp_file = tmp_path / "test_block_io.ttgir"
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
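
As a side note (not part of the patch): the four `(load_block_ptr, store_block_ptr)` combinations listed explicitly in the new `parametrize` decorator are simply the Cartesian product of `[True, False]` with itself, which could also be generated with the same `itertools.product` helper the test already uses for the `(M, N)` shapes. A minimal sketch:

```python
# Sketch only: the explicit combination list in the parametrize decorator
# covers the same cases as itertools.product([True, False], repeat=2);
# the patch itself keeps the explicit list.
import itertools

explicit = [(True, True), (False, False), (True, False), (False, True)]
generated = list(itertools.product([True, False], repeat=2))

# Same set of cases; only the enumeration order differs.
assert set(explicit) == set(generated)
```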