@@ -133,27 +133,35 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']

     if block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        """
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
-            %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-            tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        """
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """

     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{

+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
             %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
@@ -163,9 +171,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
             %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            {load_ops}

             {store_ops}

@@ -191,3 +197,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl

     if support_block_io:
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
+        if not block_ptr:
+            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
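
The hunks above elide the part of the test that materializes and runs the templated TTGIR. A minimal sketch of that step, assuming the compile-from-file pattern these TTGIR tests typically use; the temp-file name, grid, and input initialization below are illustrative, not taken from this PR:

    import torch
    import triton

    # `ir`, `M`, `N`, `dtype_str`, `device`, and the pytest `tmp_path` fixture
    # come from the surrounding test function; the file name is hypothetical.
    temp_file = tmp_path / "test_block_store.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    # Round-trip a tensor through the generated kernel: %src is read via
    # load_ops and written back to %dst via store_ops, so dst must equal src.
    a = torch.arange(M * N, device=device).reshape(M, N).to(getattr(torch, dtype_str))
    x = torch.empty_like(a)
    kernel[(1, 1, 1)](a, x)
    assert torch.equal(a, x)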