Skip to content

Commit 828be3b

Browse files
committed
Apply group broadcast to key-value (KV) private sort
Signed-off-by: jinge90 <[email protected]>
1 parent 34f6c35 commit 828be3b

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

libdevice/group_helper.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ static inline void group_barrier() {
3030
__spv::MemorySemanticsMask::WorkgroupMemory |
3131
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
3232
}
33+
34+
// Work-group broadcast helper for 64-bit values: forwards x to
// __spirv_GroupBroadcast at Workgroup scope with a local id of 0 and returns
// the result. Per SPIR-V OpGroupBroadcast semantics, every work-item in the
// group receives the x supplied by the work-item whose local id is 0.
// NOTE(review): OpGroupBroadcast is a group operation, so it presumably must
// be reached by all work-items of the work-group — confirm call sites are not
// inside divergent control flow.
static inline uint64_t group_broadcast(uint64_t x) {
35+
return __spirv_GroupBroadcast(__spv::Scope::Flag::Workgroup, x, 0);
36+
}
3337
#endif // __SPIR__ || __SPIRV__

libdevice/sort_helper.hpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -296,34 +296,38 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
296296
uint8_t *scratch, Compare comp) {
297297
const size_t local_idx = __get_wg_local_linear_id();
298298
const size_t wg_size = __get_wg_local_range();
299-
KeyT *temp_key_beg = reinterpret_cast<KeyT *>(scratch);
300-
uint64_t temp_val_unaligned =
301-
reinterpret_cast<uint64_t>(scratch + 2 * wg_size * n * sizeof(KeyT));
302-
ValT *temp_val_beg = nullptr;
303-
uint64_t temp1 = temp_val_unaligned % alignof(ValT);
304-
if (temp1)
305-
temp_val_beg =
306-
reinterpret_cast<ValT *>(temp_val_unaligned + alignof(ValT) - temp1);
307-
else
308-
temp_val_beg = reinterpret_cast<ValT *>(temp_val_unaligned);
299+
uint64_t temp_val_beg = 0, temp_key_beg = 0, internal_scratch = 0;
300+
KeyT *keys_ptr = nullptr;
301+
ValT *vals_ptr = nullptr;
302+
uint8_t *scratch_ptr = nullptr;
303+
304+
if (local_idx == 0) {
305+
uint64_t temp_val_unaligned =
306+
reinterpret_cast<uint64_t>(scratch + 2 * wg_size * n * sizeof(KeyT));
307+
uint64_t temp1 = temp_val_unaligned % alignof(ValT);
308+
temp_val_beg = (temp1 != 0) ? (temp_val_unaligned + alignof(ValT) - temp1)
309+
: temp_val_unaligned;
310+
temp_val_beg += sizeof(ValT) * n * wg_size;
311+
temp_key_beg = reinterpret_cast<uint64_t>(scratch);
312+
internal_scratch = temp_key_beg + sizeof(KeyT) * n * wg_size;
313+
}
309314

310-
uint8_t *internal_scratch =
311-
reinterpret_cast<uint8_t *>(&temp_key_beg[n * wg_size]);
312-
temp_val_beg = &temp_val_beg[n * wg_size];
315+
keys_ptr = reinterpret_cast<KeyT *>(group_broadcast(temp_key_beg));
316+
vals_ptr = reinterpret_cast<ValT *>(group_broadcast(temp_val_beg));
317+
scratch_ptr = reinterpret_cast<uint8_t *>(group_broadcast(internal_scratch));
313318

314319
for (size_t i = 0; i < n; ++i) {
315-
temp_key_beg[local_idx * n + i] = keys[i];
316-
temp_val_beg[local_idx * n + i] = vals[i];
320+
keys_ptr[local_idx * n + i] = keys[i];
321+
vals_ptr[local_idx * n + i] = vals[i];
317322
}
318323

319324
group_barrier();
320325

321-
merge_sort_key_value(temp_key_beg, temp_val_beg, n * wg_size,
322-
internal_scratch, comp);
326+
merge_sort_key_value(keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp);
323327

324328
for (size_t i = 0; i < n; ++i) {
325-
keys[i] = temp_key_beg[local_idx * n + i];
326-
vals[i] = temp_val_beg[local_idx * n + i];
329+
keys[i] = keys_ptr[local_idx * n + i];
330+
vals[i] = vals_ptr[local_idx * n + i];
327331
}
328332
}
329333

libdevice/spirv_decls.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,6 @@ extern DEVICE_EXTERNAL void
8383
__spirv_AtomicStore(int *, __spv::Scope::Flag, __spv::MemorySemanticsMask::Flag,
8484
int);
8585

86+
// Declaration of the SPIR-V group-broadcast built-in for 64-bit values.
// Arguments, in order: execution scope, the value to broadcast, and a local
// id (presumably the id of the work-item whose value is distributed, as in
// SPIR-V OpGroupBroadcast — confirm against the SPIR-V specification).
extern DEVICE_EXTERNAL uint64_t __spirv_GroupBroadcast(__spv::Scope::Flag,
87+
uint64_t, uint64_t);
8688
#endif // __SPIR__ || __SPIRV__

0 commit comments

Comments
 (0)