@@ -296,34 +296,38 @@ void private_merge_sort_key_value_close(KeyT *keys, ValT *vals, size_t n,
296296 uint8_t *scratch, Compare comp) {
297297 const size_t local_idx = __get_wg_local_linear_id ();
298298 const size_t wg_size = __get_wg_local_range ();
299- KeyT *temp_key_beg = reinterpret_cast <KeyT *>(scratch);
300- uint64_t temp_val_unaligned =
301- reinterpret_cast <uint64_t >(scratch + 2 * wg_size * n * sizeof (KeyT));
302- ValT *temp_val_beg = nullptr ;
303- uint64_t temp1 = temp_val_unaligned % alignof (ValT);
304- if (temp1)
305- temp_val_beg =
306- reinterpret_cast <ValT *>(temp_val_unaligned + alignof (ValT) - temp1);
307- else
308- temp_val_beg = reinterpret_cast <ValT *>(temp_val_unaligned);
299+ uint64_t temp_val_beg = 0 , temp_key_beg = 0 , internal_scratch = 0 ;
300+ KeyT *keys_ptr = nullptr ;
301+ ValT *vals_ptr = nullptr ;
302+ uint8_t *scratch_ptr = nullptr ;
303+
304+ if (local_idx == 0 ) {
305+ uint64_t temp_val_unaligned =
306+ reinterpret_cast <uint64_t >(scratch + 2 * wg_size * n * sizeof (KeyT));
307+ uint64_t temp1 = temp_val_unaligned % alignof (ValT);
308+ temp_val_beg = (temp1 != 0 ) ? (temp_val_unaligned + alignof (ValT) - temp1)
309+ : temp_val_unaligned;
310+ temp_val_beg += sizeof (ValT) * n * wg_size;
311+ temp_key_beg = reinterpret_cast <uint64_t >(scratch);
312+ internal_scratch = temp_key_beg + sizeof (KeyT) * n * wg_size;
313+ }
309314
310- uint8_t *internal_scratch =
311- reinterpret_cast <uint8_t *>(&temp_key_beg[n * wg_size] );
312- temp_val_beg = &temp_val_beg[n * wg_size] ;
315+ keys_ptr = reinterpret_cast <KeyT *>( group_broadcast (temp_key_beg));
316+ vals_ptr = reinterpret_cast <ValT *>(group_broadcast (temp_val_beg) );
317+ scratch_ptr = reinterpret_cast < uint8_t *>( group_broadcast (internal_scratch)) ;
313318
314319 for (size_t i = 0 ; i < n; ++i) {
315- temp_key_beg [local_idx * n + i] = keys[i];
316- temp_val_beg [local_idx * n + i] = vals[i];
320+ keys_ptr [local_idx * n + i] = keys[i];
321+ vals_ptr [local_idx * n + i] = vals[i];
317322 }
318323
319324 group_barrier ();
320325
321- merge_sort_key_value (temp_key_beg, temp_val_beg, n * wg_size,
322- internal_scratch, comp);
326+ merge_sort_key_value (keys_ptr, vals_ptr, n * wg_size, scratch_ptr, comp);
323327
324328 for (size_t i = 0 ; i < n; ++i) {
325- keys[i] = temp_key_beg [local_idx * n + i];
326- vals[i] = temp_val_beg [local_idx * n + i];
329+ keys[i] = keys_ptr [local_idx * n + i];
330+ vals[i] = vals_ptr [local_idx * n + i];
327331 }
328332}
329333
0 commit comments