2727
2828#pragma once
2929
30+ #include < array>
3031#include < cstdint>
3132#include < limits>
3233#include < stdexcept>
@@ -477,10 +478,10 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
477478 sycl::access::address_space::local_space>;
478479 using TempStorageT = sycl::local_accessor<std::uint32_t , 1 >;
479480
480- sycl::sub_group sgroup;
481- std::uint32_t lid;
482- std::uint32_t item_mask;
483- AtomicT atomic_peer_mask;
481+ const sycl::sub_group sgroup;
482+ const std::uint32_t lid;
483+ const std::uint32_t item_mask;
484+ const AtomicT atomic_peer_mask;
484485
485486 peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT lacc)
486487 : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -490,7 +491,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
490491
491492 std::uint32_t peer_contribution (OffsetT &new_offset_id,
492493 OffsetT offset_prefix,
493- bool wi_bit_set)
494+ bool wi_bit_set) const
494495 {
495496 // reset mask for each radix state
496497 if (lid == 0 )
@@ -523,8 +524,8 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
523524 using ItemType = sycl::nd_item<1 >;
524525 using SubGroupType = sycl::sub_group;
525526
526- SubGroupType sgroup;
527- std::uint32_t sg_size;
527+ const SubGroupType sgroup;
528+ const std::uint32_t sg_size;
528529
529530 peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT)
530531 : sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0 ])
@@ -533,7 +534,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::scan_then_broadcast>
533534
534535 std::uint32_t peer_contribution (OffsetT &new_offset_id,
535536 OffsetT offset_prefix,
536- bool wi_bit_set)
537+ bool wi_bit_set) const
537538 {
538539 const std::uint32_t contrib{wi_bit_set ? std::uint32_t {1 }
539540 : std::uint32_t {0 }};
@@ -567,9 +568,9 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
567568public:
568569 using TempStorageT = empty_storage;
569570
570- sycl::sub_group sgroup;
571- std::uint32_t lid;
572- sycl::ext::oneapi::sub_group_mask item_sg_mask;
571+ const sycl::sub_group sgroup;
572+ const std::uint32_t lid;
573+ const sycl::ext::oneapi::sub_group_mask item_sg_mask;
573574
574575 peer_prefix_helper (sycl::nd_item<1 > ndit, TempStorageT)
575576 : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
@@ -580,7 +581,7 @@ struct peer_prefix_helper<OffsetT, peer_prefix_algo::subgroup_ballot>
580581
581582 std::uint32_t peer_contribution (OffsetT &new_offset_id,
582583 OffsetT offset_prefix,
583- bool wi_bit_set)
584+ bool wi_bit_set) const
584585 {
585586 // set local id's bit to 1 if the bucket value matches the radix state
586587 auto peer_mask = sycl::ext::oneapi::group_ballot (sgroup, wi_bit_set);
@@ -750,7 +751,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
750751 const std::uint32_t tail_size = (seg_end - seg_start) % sg_size;
751752 seg_end -= tail_size;
752753
753- PeerHelper peer_prefix_hlp (ndit, peer_temp);
754+ const PeerHelper peer_prefix_hlp (ndit, peer_temp);
754755
755756 // find offsets for the same values within a segment and fill the
756757 // resulting buffer
@@ -967,8 +968,13 @@ struct parallel_radix_sort_iteration_step
967968
968969 // 3. Reorder Phase
969970 sycl::event reorder_ev{};
970- if (reorder_sg_size == 8 || reorder_sg_size == 16 ||
971- reorder_sg_size == 32 )
971+ // subgroup_ballot-based peer algo uses extract_bits to populate
972+ // uint32_t mask and hence relies on sub-group to be 32 or narrower
973+ constexpr std::size_t sg32_v = 32u ;
974+ constexpr std::size_t sg16_v = 16u ;
975+ constexpr std::size_t sg08_v = 8u ;
976+ if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size ||
977+ sg08_v == reorder_sg_size)
972978 {
973979 constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;
974980
0 commit comments