@@ -292,8 +292,10 @@ void merge_sort_key_value(KeyT *keys, ValT *vals, size_t n, uint8_t *scratch,
292292// max(alignof(KeyT), alignof(ValT))
293293// The scrach memory alignment is max(alignof(KeyT), alignof(ValT))
294294template <typename KeyT, typename ValT, typename Compare>
295- void private_merge_sort_key_value_close (KeyT *keys, ValT *vals, size_t n,
296- uint8_t *scratch, Compare comp) {
295+ static void private_merge_sort_key_value_helper (KeyT *keys, ValT *vals,
296+ size_t n, uint8_t *scratch,
297+ Compare comp, KeyT **keys_back,
298+ ValT **vals_back) {
297299 const size_t local_idx = __get_wg_local_linear_id ();
298300 const size_t wg_size = __get_wg_local_range ();
299301 uint64_t temp_val_beg = 0 , temp_key_beg = 0 , internal_scratch = 0 ;
@@ -315,6 +317,8 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
315317 keys_ptr = reinterpret_cast <KeyT *>(group_broadcast (temp_key_beg));
316318 vals_ptr = reinterpret_cast <ValT *>(group_broadcast (temp_val_beg));
317319 scratch_ptr = reinterpret_cast <uint8_t *>(group_broadcast (internal_scratch));
320+ *keys_back = keys_ptr;
321+ *vals_back = vals_ptr;
318322
319323 for (size_t i = 0 ; i < n; ++i) {
320324 keys_ptr[local_idx * n + i] = keys[i];
@@ -324,10 +328,41 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
324328 group_barrier ();
325329
326330 merge_sort_key_value (keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp);
331+ }
332+
333+ template <typename KeyT, typename ValT, typename Compare>
334+ void private_merge_sort_key_value_close (KeyT *keys, ValT *vals, size_t n,
335+ uint8_t *scratch, Compare comp) {
336+
337+ KeyT *keys_back = nullptr ;
338+ ValT *vals_back = nullptr ;
339+ private_merge_sort_key_value_helper (keys, vals, n, scratch, comp, &keys_back,
340+ &vals_back);
341+
342+
343+ const size_t local_idx = __get_wg_local_linear_id ();
344+ for (size_t i = 0 ; i < n; ++i) {
345+ keys[i] = keys_back[local_idx * n + i];
346+ vals[i] = vals_back[local_idx * n + i];
347+ }
348+ }
349+
350+ template <typename KeyT, typename ValT, typename Compare>
351+ void private_merge_sort_key_value_spread (KeyT *keys, ValT *vals, size_t n,
352+ uint8_t *scratch, Compare comp) {
353+
354+ KeyT *keys_back = nullptr ;
355+ ValT *vals_back = nullptr ;
356+ private_merge_sort_key_value_helper (keys, vals, n, scratch, comp, &keys_back,
357+ &vals_back);
358+
359+
360+ const size_t local_idx = __get_wg_local_linear_id ();
361+ const size_t wg_size = __get_wg_local_range ();
327362
328363 for (size_t i = 0 ; i < n; ++i) {
329- keys[i] = keys_ptr[local_idx * n + i ];
330- vals[i] = vals_ptr[local_idx * n + i ];
364+ keys[i] = keys_back[wg_size * i + local_idx ];
365+ vals[i] = vals_back[wg_size * i + local_idx ];
331366 }
332367}
333368
0 commit comments