@@ -51,46 +51,43 @@ template <typename Group, size_t Extent> class group_with_scratchpad {
5151// ---- sorters
5252template <typename Compare = std::less<>> class default_sorter {
5353 Compare comp;
54- std::byte *scratch;
55- size_t scratch_size;
54+ sycl::span<std::byte> scratch;
5655
5756public:
5857 template <size_t Extent>
5958 default_sorter (sycl::span<std::byte, Extent> scratch_,
6059 Compare comp_ = Compare())
61- : comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size() ) {}
60+ : comp(comp_), scratch(scratch_) {}
6261
6362 template <typename Group, typename Ptr>
64- void operator ()(Group g, Ptr first, Ptr last) {
63+ void operator ()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
64+ [[maybe_unused]] Ptr last) {
6565#ifdef __SYCL_DEVICE_ONLY__
66- using T = typename sycl::detail::GetValueType<Ptr>::type;
67- if (scratch_size >= memory_required<T>(Group::fence_scope, last - first))
68- sycl::detail::merge_sort (g, first, last - first, comp, scratch);
69- // TODO: it's better to add else branch
66+ // Per the extension specification, if the scratch size is less than the value
67+ // returned by memory_required, the behavior is undefined, so we don't check
68+ // that the scratch size satisfies the requirement.
69+ sycl::detail::merge_sort (g, first, last - first, comp, scratch. data ());
7070#else
71- (void )g;
72- (void )first;
73- (void )last;
7471 throw sycl::exception (
7572 std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
7673 " default_sorter constructor is not supported on host device." );
7774#endif
7875 }
7976
80- template <typename Group, typename T> T operator ()(Group g, T val) {
77+ template <typename Group, typename T>
78+ T operator ()([[maybe_unused]] Group g, T val) {
8179#ifdef __SYCL_DEVICE_ONLY__
80+ // Per the extension specification, if the scratch size is less than the value
81+ // returned by memory_required, the behavior is undefined, so we don't check
82+ // that the scratch size satisfies the requirement.
8283 auto range_size = g.get_local_range ().size ();
83- if (scratch_size >= memory_required<T>(Group::fence_scope, range_size)) {
84- size_t local_id = g.get_local_linear_id ();
85- T *temp = reinterpret_cast <T *>(scratch);
86- ::new (temp + local_id) T (val);
87- sycl::detail::merge_sort (g, temp, range_size, comp,
88- scratch + range_size * sizeof (T));
89- val = temp[local_id];
90- }
91- // TODO: it's better to add else branch
84+ size_t local_id = g.get_local_linear_id ();
85+ T *temp = reinterpret_cast <T *>(scratch.data ());
86+ ::new (temp + local_id) T (val);
87+ sycl::detail::merge_sort (g, temp, range_size, comp,
88+ scratch.data () + range_size * sizeof (T));
89+ val = temp[local_id];
9290#else
93- (void )g;
9491 throw sycl::exception (
9592 std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
9693 " default_sorter operator() is not supported on host device." );
@@ -129,62 +126,56 @@ template <typename ValT, sorting_order OrderT = sorting_order::ascending,
129126 unsigned int BitsPerPass = 4 >
130127class radix_sorter {
131128
132- std::byte * scratch = nullptr ;
129+ sycl::span< std::byte> scratch;
133130 uint32_t first_bit = 0 ;
134131 uint32_t last_bit = 0 ;
135- size_t scratch_size = 0 ;
136132
137133 static constexpr uint32_t bits = BitsPerPass;
134+ using bitset_t = std::bitset<sizeof (ValT) * CHAR_BIT>;
138135
139136public:
140137 template <size_t Extent>
141138 radix_sorter (sycl::span<std::byte, Extent> scratch_,
142- const std::bitset<sizeof (ValT) *CHAR_BIT> mask =
143- std::bitset<sizeof (ValT) * CHAR_BIT>(
144- (std::numeric_limits<unsigned long long >::max)()))
145- : scratch(scratch_.data()), scratch_size(scratch_.size()) {
139+ const bitset_t mask = bitset_t {}.set())
140+ : scratch(scratch_) {
146141 static_assert ((std::is_arithmetic<ValT>::value ||
147142 std::is_same<ValT, sycl::half>::value ||
148143 std::is_same<ValT, sycl::ext::oneapi::bfloat16>::value),
149144 " radix sort is not usable" );
150145
151- first_bit = 0 ;
152- while (first_bit < mask.size () && !mask[first_bit])
153- ++first_bit;
154-
155- last_bit = first_bit;
156- while (last_bit < mask.size () && mask[last_bit])
157- ++last_bit;
146+ for (first_bit = 0 ; first_bit < mask.size () && !mask[first_bit];
147+ ++first_bit)
148+ ;
149+ for (last_bit = first_bit; last_bit < mask.size () && mask[last_bit];
150+ ++last_bit)
151+ ;
158152 }
159153
160154 template <typename GroupT, typename PtrT>
161- void operator ()(GroupT g, PtrT first, PtrT last) {
162- (void )g;
163- (void )first;
164- (void )last;
155+ void operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] PtrT first,
156+ [[maybe_unused]] PtrT last) {
165157#ifdef __SYCL_DEVICE_ONLY__
166158 sycl::detail::privateDynamicSort</* is_key_value=*/ false ,
167159 OrderT == sorting_order::ascending,
168160 /* empty*/ 1 , BitsPerPass>(
169- g, first, /* empty*/ first, ( last - first) > 0 ? (last - first) : 0 ,
170- scratch, first_bit, last_bit);
161+ g, first, /* empty*/ first, last - first, scratch. data (), first_bit ,
162+ last_bit);
171163#else
172164 throw sycl::exception (
173165 std::error_code (PI_ERROR_INVALID_DEVICE, sycl::sycl_category ()),
174166 " radix_sorter is not supported on host device." );
175167#endif
176168 }
177169
178- template <typename GroupT> ValT operator ()(GroupT g, ValT val) {
179- (void )g;
180- (void )val;
170+ template <typename GroupT>
171+ ValT operator ()([[maybe_unused]] GroupT g, [[maybe_unused]] ValT val) {
181172#ifdef __SYCL_DEVICE_ONLY__
182173 ValT result[]{val};
183174 sycl::detail::privateStaticSort</* is_key_value=*/ false ,
184175 /* is_blocked=*/ true ,
185176 OrderT == sorting_order::ascending,
186177 /* items_per_work_item=*/ 1 , bits>(
187- g, result, /* empty*/ result, scratch, first_bit, last_bit);
178+ g, result, /* empty*/ result, scratch. data () , first_bit, last_bit);
188179 return result[0 ];
189180#else
190181 throw sycl::exception (
@@ -193,20 +184,16 @@ class radix_sorter {
193184#endif
194185 }
195186
196- static constexpr size_t memory_required (sycl::memory_scope scope ,
187+ static constexpr size_t memory_required (sycl::memory_scope,
197188 size_t range_size) {
198- // Scope is not important so far
199- (void )scope;
200189 return range_size * sizeof (ValT) +
201190 (1 << bits) * range_size * sizeof (uint32_t ) + alignof (uint32_t );
202191 }
203192
204193 // memory_helpers
205194 template <int dimensions = 1 >
206- static constexpr size_t memory_required (sycl::memory_scope scope ,
195+ static constexpr size_t memory_required (sycl::memory_scope,
207196 sycl::range<dimensions> local_range) {
208- // Scope is not important so far
209- (void )scope;
210197 return (std::max)(local_range.size () * sizeof (ValT),
211198 local_range.size () * (1 << bits) * sizeof (uint32_t ));
212199 }
0 commit comments