@@ -207,41 +207,39 @@ struct group_row_reduce_t {
207
207
}
208
208
};
209
209
210
+ // Set default load type to block_2d because only the block_2d load/store will
211
+ // ensure boundary safety.
210
212
template <typename scalar_t , typename tile_desc_t , typename mem_desc_t >
211
213
void store_tile (subgroup::tile_t <scalar_t , tile_desc_t >* src, mem_desc_t dst) {
212
- using store_t = subgroup::mem_payload_t <
213
- mem_desc_t , tile_desc_t ,
214
- subgroup::msg_type_v<tile_desc_t , mem_desc_t ::space>, gpu_arch::Xe>;
214
+ using store_t = subgroup::mem_payload_t <mem_desc_t , tile_desc_t ,
215
+ msg_type::block_2d, gpu_arch::Xe>;
215
216
store_t store (dst);
216
217
subgroup::tile_store (*src, store);
217
218
}
218
219
219
220
template <typename scalar_t , typename tile_desc_t , typename mem_desc_t >
220
221
void store_tile (subgroup::tile_t <scalar_t , tile_desc_t >* src, mem_desc_t dst,
221
222
int32_t tile_offset_x, int32_t tile_offset_y) {
222
- using store_t = subgroup::mem_payload_t <
223
- mem_desc_t , tile_desc_t ,
224
- subgroup::msg_type_v<tile_desc_t , mem_desc_t ::space>, gpu_arch::Xe>;
223
+ using store_t = subgroup::mem_payload_t <mem_desc_t , tile_desc_t ,
224
+ msg_type::block_2d, gpu_arch::Xe>;
225
225
dst.update_coord (tile_offset_x, tile_offset_y);
226
226
store_t store (dst);
227
227
subgroup::tile_store (*src, store);
228
228
}
229
229
230
230
template <typename scalar_t , typename tile_desc_t , typename mem_desc_t >
231
231
void load_tile (subgroup::tile_t <scalar_t , tile_desc_t >* dst, mem_desc_t src) {
232
- using load_t = subgroup::mem_payload_t <
233
- mem_desc_t , tile_desc_t ,
234
- subgroup::msg_type_v<tile_desc_t , mem_desc_t ::space>, gpu_arch::Xe>;
232
+ using load_t = subgroup::mem_payload_t <mem_desc_t , tile_desc_t ,
233
+ msg_type::block_2d, gpu_arch::Xe>;
235
234
load_t load (src);
236
235
subgroup::tile_load (*dst, load);
237
236
}
238
237
239
238
template <typename scalar_t , typename tile_desc_t , typename mem_desc_t >
240
239
void load_tile (subgroup::tile_t <scalar_t , tile_desc_t >* dst, mem_desc_t src,
241
240
int32_t tile_offset_x, int32_t tile_offset_y) {
242
- using load_t = subgroup::mem_payload_t <
243
- mem_desc_t , tile_desc_t ,
244
- subgroup::msg_type_v<tile_desc_t , mem_desc_t ::space>, gpu_arch::Xe>;
241
+ using load_t = subgroup::mem_payload_t <mem_desc_t , tile_desc_t ,
242
+ msg_type::block_2d, gpu_arch::Xe>;
245
243
src.update_coord (tile_offset_x, tile_offset_y);
246
244
load_t load (src);
247
245
subgroup::tile_load (*dst, load);
0 commit comments