Skip to content

Commit 8a03ba5

Browse files
committed
add private KV spread sort
Signed-off-by: jinge90 <[email protected]>
1 parent 828be3b commit 8a03ba5

11 files changed

+2605
-857
lines changed

libdevice/fallback-gsort.cpp

Lines changed: 1042 additions & 0 deletions
Large diffs are not rendered by default.

libdevice/sort_helper.hpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,10 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch,
292292
// max(alignof(KeyT), alignof(ValT))
293293
// The scratch memory alignment is max(alignof(KeyT), alignof(ValT))
294294
template <typename KeyT, typename ValT, typename Compare>
295-
void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
296-
uint8_t *scratch, Compare comp) {
295+
static void private_merge_sort_key_value_helper(KeyT *keys, ValT *vals,
296+
size_t n, uint8_t *scratch,
297+
Compare comp, KeyT **keys_back,
298+
ValT **vals_back) {
297299
const size_t local_idx = __get_wg_local_linear_id();
298300
const size_t wg_size = __get_wg_local_range();
299301
uint64_t temp_val_beg = 0, temp_key_beg = 0, internal_scratch = 0;
@@ -315,6 +317,8 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
315317
keys_ptr = reinterpret_cast<KeyT *>(group_broadcast(temp_key_beg));
316318
vals_ptr = reinterpret_cast<ValT *>(group_broadcast(temp_val_beg));
317319
scratch_ptr = reinterpret_cast<uint8_t *>(group_broadcast(internal_scratch));
320+
*keys_back = keys_ptr;
321+
*vals_back = vals_ptr;
318322

319323
for (size_t i = 0; i < n; ++i) {
320324
keys_ptr[local_idx * n + i] = keys[i];
@@ -324,10 +328,41 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
324328
group_barrier();
325329

326330
merge_sort_key_value(keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp);
331+
}
332+
333+
// "Close" variant of the private key/value merge sort: delegates the sort to
// private_merge_sort_key_value_helper, then copies back into this work-item's
// keys/vals the n consecutive elements starting at index local_idx * n of the
// buffers the helper reports through keys_back/vals_back.
template <typename KeyT, typename ValT, typename Compare>
334+
void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
335+
uint8_t *scratch, Compare comp) {
336+
337+
KeyT *keys_back = nullptr;
338+
ValT *vals_back = nullptr;
339+
private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back,
340+
&vals_back);
341+
342+
343+
const size_t local_idx = __get_wg_local_linear_id();
// Contiguous read-back: work-item local_idx takes elements
// [local_idx * n, local_idx * n + n) of the helper's output buffers.
344+
for (size_t i = 0; i < n; ++i) {
345+
keys[i] = keys_back[local_idx * n + i];
346+
vals[i] = vals_back[local_idx * n + i];
347+
}
348+
}
349+
350+
// "Spread" variant of the private key/value merge sort: delegates the sort to
// private_merge_sort_key_value_helper, then fills this work-item's keys/vals
// with elements taken at a stride of wg_size from the buffers the helper
// reports through keys_back/vals_back (element i comes from index
// wg_size * i + local_idx).
template <typename KeyT, typename ValT, typename Compare>
351+
void private_merge_sort_key_value_spread(KeyT *keys, ValT *vals, size_t n,
352+
uint8_t *scratch, Compare comp) {
353+
354+
KeyT *keys_back = nullptr;
355+
ValT *vals_back = nullptr;
356+
private_merge_sort_key_value_helper(keys, vals, n, scratch, comp, &keys_back,
357+
&vals_back);
358+
359+
360+
const size_t local_idx = __get_wg_local_linear_id();
361+
const size_t wg_size = __get_wg_local_range();
327362

// Strided read-back: consecutive elements of the helper's output go to
// consecutive work-items, so this work-item's i-th element is at
// wg_size * i + local_idx.
328363
for (size_t i = 0; i < n; ++i) {
329-
keys[i] = keys_ptr[local_idx * n + i];
330-
vals[i] = vals_ptr[local_idx * n + i];
364+
keys[i] = keys_back[wg_size * i + local_idx];
365+
vals[i] = vals_back[wg_size * i + local_idx];
331366
}
332367
}
333368

0 commit comments

Comments
 (0)