@@ -70,13 +70,14 @@ template <typename SizeT,
7070 int > = 0 >
7171std::uint32_t ceil_log2 (SizeT n)
7272{
73+ // if n >= 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
74+ // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == b + floor_log2(q)
75+ // ceil_log2(n) == 1 + floor_log2(n - 1) for n >= 2
7376 if (n <= 1 )
7477 return std::uint32_t {1 };
7578
7679 std::uint32_t exp{1 };
7780 --n;
78- // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
79- // ceil_log2(q * 2^b + r) == ceil_log2(q * 2^b) == q + ceil_log2(n1)
8081 if (n >= (SizeT{1 } << 32 )) {
8182 n >>= 32 ;
8283 exp += 32 ;
@@ -137,16 +138,20 @@ template <bool is_ascending,
137138std::make_unsigned_t <IntT> order_preserving_cast (IntT val)
138139{
139140 using UIntT = std::make_unsigned_t <IntT>;
140- // ascending_mask: 100..0
141- constexpr UIntT ascending_mask =
142- (UIntT (1 ) << std::numeric_limits<IntT>::digits);
143- // descending_mask: 011..1
144- constexpr UIntT descending_mask = (std::numeric_limits<UIntT>::max () >> 1 );
145-
146- constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask;
147141 const UIntT uint_val = sycl::bit_cast<UIntT>(val);
148142
149- return (uint_val ^ mask);
143+ if constexpr (is_ascending) {
144+ // ascending_mask: 100..0
145+ constexpr UIntT ascending_mask =
146+ (UIntT (1 ) << std::numeric_limits<IntT>::digits);
147+ return (uint_val ^ ascending_mask);
148+ }
149+ else {
150+ // descending_mask: 011..1
151+ constexpr UIntT descending_mask =
152+ (std::numeric_limits<UIntT>::max () >> 1 );
153+ return (uint_val ^ descending_mask);
154+ }
150155}
151156
152157template <bool is_ascending> std::uint16_t order_preserving_cast (sycl::half val)
@@ -1045,10 +1050,10 @@ template <typename Names, std::uint16_t... Constants>
10451050class radix_sort_one_wg_krn ;
10461051
10471052template <typename KernelNameBase,
1048- uint16_t wg_size = 256 ,
1049- uint16_t block_size = 16 ,
1053+ std:: uint16_t wg_size = 256 ,
1054+ std:: uint16_t block_size = 16 ,
10501055 std::uint32_t radix = 4 ,
1051- uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16 )>
1056+ std:: uint16_t req_sub_group_size = (block_size < 4 ? 32 : 16 )>
10521057struct subgroup_radix_sort
10531058{
10541059private:
@@ -1062,8 +1067,8 @@ struct subgroup_radix_sort
10621067public:
10631068 template <typename ValueT, typename OutputT, typename ProjT>
10641069 sycl::event operator ()(sycl::queue &exec_q,
1065- size_t n_iters,
1066- size_t n_to_sort,
1070+ std:: size_t n_iters,
1071+ std:: size_t n_to_sort,
10671072 ValueT *input_ptr,
10681073 OutputT *output_ptr,
10691074 ProjT proj_op,
@@ -1160,8 +1165,8 @@ struct subgroup_radix_sort
11601165 };
11611166
11621167 static_assert (wg_size <= 1024 );
1163- static constexpr uint16_t bin_count = (1 << radix);
1164- static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1 ;
1168+ static constexpr std:: uint16_t bin_count = (1 << radix);
1169+ static constexpr std:: uint16_t counter_buf_sz = wg_size * bin_count + 1 ;
11651170
11661171 enum class temp_allocations
11671172 {
@@ -1177,7 +1182,7 @@ struct subgroup_radix_sort
11771182 assert (n <= (SizeT (1 ) << 16 ));
11781183
11791184 constexpr auto req_slm_size_counters =
1180- counter_buf_sz * sizeof (uint32_t );
1185+ counter_buf_sz * sizeof (std:: uint16_t );
11811186
11821187 const auto &dev = exec_q.get_device ();
11831188
@@ -1212,9 +1217,9 @@ struct subgroup_radix_sort
12121217 typename SLM_value_tag,
12131218 typename SLM_counter_tag>
12141219 sycl::event operator ()(sycl::queue &exec_q,
1215- size_t n_iters,
1216- size_t n_batch_size,
1217- size_t n_values,
1220+ std:: size_t n_iters,
1221+ std:: size_t n_batch_size,
1222+ std:: size_t n_values,
12181223 InputT *input_arr,
12191224 OutputT *output_arr,
12201225 const ProjT &proj_op,
@@ -1228,7 +1233,7 @@ struct subgroup_radix_sort
12281233 assert (n_values <= static_cast <std::size_t >(block_size) *
12291234 static_cast <std::size_t >(wg_size));
12301235
1231- uint16_t n = static_cast <uint16_t >(n_values);
1236+ const std:: uint16_t n = static_cast <std:: uint16_t >(n_values);
12321237 static_assert (std::is_same_v<std::remove_cv_t <InputT>, OutputT>);
12331238
12341239 using ValueT = OutputT;
@@ -1237,17 +1242,18 @@ struct subgroup_radix_sort
12371242
12381243 TempBuf<ValueT, SLM_value_tag> buf_val (
12391244 n_batch_size, static_cast <std::size_t >(block_size * wg_size));
1240- TempBuf<std::uint32_t , SLM_counter_tag> buf_count (
1245+ TempBuf<std::uint16_t , SLM_counter_tag> buf_count (
12411246 n_batch_size, static_cast <std::size_t >(counter_buf_sz));
12421247
12431248 sycl::range<1 > lRange{wg_size};
12441249
12451250 sycl::event sort_ev;
1246- std::vector<sycl::event> deps = depends;
1251+ std::vector<sycl::event> deps{ depends} ;
12471252
1248- std::size_t n_batches = (n_iters + n_batch_size - 1 ) / n_batch_size;
1253+ const std::size_t n_batches =
1254+ (n_iters + n_batch_size - 1 ) / n_batch_size;
12491255
1250- for (size_t batch_id = 0 ; batch_id < n_batches; ++batch_id) {
1256+ for (std:: size_t batch_id = 0 ; batch_id < n_batches; ++batch_id) {
12511257
12521258 const std::size_t block_start = batch_id * n_batch_size;
12531259
@@ -1286,46 +1292,49 @@ struct subgroup_radix_sort
12861292 const std::size_t iter_exchange_offset =
12871293 iter_id * exchange_acc_iter_stride;
12881294
1289- uint16_t wi = ndit.get_local_linear_id ();
1290- uint16_t begin_bit = 0 ;
1295+ std:: uint16_t wi = ndit.get_local_linear_id ();
1296+ std:: uint16_t begin_bit = 0 ;
12911297
1292- constexpr uint16_t end_bit =
1298+ constexpr std:: uint16_t end_bit =
12931299 number_of_bits_in_type<KeyT>();
12941300
1295- // copy from input array into values
1301+ // copy from input array into values
12961302#pragma unroll
1297- for (uint16_t i = 0 ; i < block_size; ++i) {
1298- const uint16_t id = wi * block_size + i;
1299- if (id < n)
1300- values[i] = std::move (
1301- this_input_arr[iter_val_offset + id]) ;
1303+ for (std:: uint16_t i = 0 ; i < block_size; ++i) {
1304+ const std:: uint16_t id = wi * block_size + i;
1305+ values[i] =
1306+ (id < n) ? this_input_arr[iter_val_offset + id]
1307+ : ValueT{} ;
13021308 }
13031309
13041310 while (true ) {
13051311 // indices for indirect access in the "re-order"
13061312 // phase
1307- uint16_t indices[block_size];
1313+ std:: uint16_t indices[block_size];
13081314 {
13091315 // pointers to bucket's counters
1310- uint32_t *counters[block_size];
1316+ std:: uint16_t *counters[block_size];
13111317
13121318 // counting phase
13131319 auto pcounter =
13141320 get_accessor_pointer (counter_acc) +
13151321 (wi + iter_counter_offset);
13161322
1317- // initialize counters
1323+ // initialize counters
13181324#pragma unroll
1319- for (uint16_t i = 0 ; i < bin_count; ++i)
1320- pcounter[i * wg_size] = std::uint32_t {0 };
1325+ for (std:: uint16_t i = 0 ; i < bin_count; ++i)
1326+ pcounter[i * wg_size] = std::uint16_t {0 };
13211327
13221328 sycl::group_barrier (ndit.get_group ());
13231329
13241330 if (is_ascending) {
13251331#pragma unroll
1326- for (uint16_t i = 0 ; i < block_size; ++i) {
1327- const uint16_t id = wi * block_size + i;
1328- constexpr uint16_t bin_mask =
1332+ for (std::uint16_t i = 0 ; i < block_size;
1333+ ++i)
1334+ {
1335+ const std::uint16_t id =
1336+ wi * block_size + i;
1337+ constexpr std::uint16_t bin_mask =
13291338 bin_count - 1 ;
13301339
13311340 // points to the padded element, i.e. id
@@ -1334,7 +1343,7 @@ struct subgroup_radix_sort
13341343 default_out_of_range_bin_id =
13351344 bin_mask;
13361345
1337- const uint16_t bin =
1346+ const std:: uint16_t bin =
13381347 (id < n)
13391348 ? get_bucket_id<bin_mask>(
13401349 order_preserving_cast<
@@ -1352,9 +1361,12 @@ struct subgroup_radix_sort
13521361 }
13531362 else {
13541363#pragma unroll
1355- for (uint16_t i = 0 ; i < block_size; ++i) {
1356- const uint16_t id = wi * block_size + i;
1357- constexpr uint16_t bin_mask =
1364+ for (std::uint16_t i = 0 ; i < block_size;
1365+ ++i)
1366+ {
1367+ const std::uint16_t id =
1368+ wi * block_size + i;
1369+ constexpr std::uint16_t bin_mask =
13581370 bin_count - 1 ;
13591371
13601372 // points to the padded element, i.e. id
@@ -1363,7 +1375,7 @@ struct subgroup_radix_sort
13631375 default_out_of_range_bin_id =
13641376 bin_mask;
13651377
1366- const uint16_t bin =
1378+ const std:: uint16_t bin =
13671379 (id < n)
13681380 ? get_bucket_id<bin_mask>(
13691381 order_preserving_cast<
@@ -1386,29 +1398,31 @@ struct subgroup_radix_sort
13861398 {
13871399
13881400 // scan contiguous numbers
1389- uint16_t bin_sum[bin_count];
1401+ std:: uint16_t bin_sum[bin_count];
13901402 const std::size_t counter_offset0 =
13911403 iter_counter_offset + wi * bin_count;
13921404 bin_sum[0 ] = counter_acc[counter_offset0];
13931405
13941406#pragma unroll
1395- for (uint16_t i = 1 ; i < bin_count; ++i)
1407+ for (std::uint16_t i = 1 ; i < bin_count;
1408+ ++i)
13961409 bin_sum[i] =
13971410 bin_sum[i - 1 ] +
13981411 counter_acc[counter_offset0 + i];
13991412
14001413 sycl::group_barrier (ndit.get_group ());
14011414
14021415 // exclusive scan local sum
1403- uint16_t sum_scan =
1416+ std:: uint16_t sum_scan =
14041417 sycl::exclusive_scan_over_group (
14051418 ndit.get_group (),
14061419 bin_sum[bin_count - 1 ],
1407- sycl::plus<uint16_t >());
1420+ sycl::plus<std:: uint16_t >());
14081421
14091422// add to local sum, generate exclusive scan result
14101423#pragma unroll
1411- for (uint16_t i = 0 ; i < bin_count; ++i)
1424+ for (std::uint16_t i = 0 ; i < bin_count;
1425+ ++i)
14121426 counter_acc[counter_offset0 + i + 1 ] =
14131427 sum_scan + bin_sum[i];
14141428
@@ -1420,51 +1434,50 @@ struct subgroup_radix_sort
14201434 }
14211435
14221436#pragma unroll
1423- for (uint16_t i = 0 ; i < block_size; ++i) {
1437+ for (std:: uint16_t i = 0 ; i < block_size; ++i) {
14241438 // a global index is a local offset plus a
14251439 // global base index
14261440 indices[i] += *counters[i];
14271441 }
1442+
1443+ sycl::group_barrier (ndit.get_group ());
14281444 }
14291445
14301446 begin_bit += radix;
14311447
14321448 // "re-order" phase
14331449 sycl::group_barrier (ndit.get_group ());
14341450 if (begin_bit >= end_bit) {
1435- // the last iteration - writing out the result
1451+ // the last iteration - writing out the result
14361452#pragma unroll
1437- for (uint16_t i = 0 ; i < block_size; ++i) {
1438- const uint16_t r = indices[i];
1453+ for (std:: uint16_t i = 0 ; i < block_size; ++i) {
1454+ const std:: uint16_t r = indices[i];
14391455 if (r < n) {
1440- // move the values to source range and
1441- // destroy the values
14421456 this_output_arr[iter_val_offset + r] =
1443- std::move ( values[i]) ;
1457+ values[i];
14441458 }
14451459 }
14461460
14471461 return ;
14481462 }
14491463
1450- // data exchange
1464+ // data exchange
14511465#pragma unroll
1452- for (uint16_t i = 0 ; i < block_size; ++i) {
1453- const uint16_t r = indices[i];
1466+ for (std:: uint16_t i = 0 ; i < block_size; ++i) {
1467+ const std:: uint16_t r = indices[i];
14541468 if (r < n)
14551469 exchange_acc[iter_exchange_offset + r] =
1456- std::move ( values[i]) ;
1470+ values[i];
14571471 }
14581472
14591473 sycl::group_barrier (ndit.get_group ());
14601474
14611475#pragma unroll
1462- for (uint16_t i = 0 ; i < block_size; ++i) {
1463- const uint16_t id = wi * block_size + i;
1476+ for (std:: uint16_t i = 0 ; i < block_size; ++i) {
1477+ const std:: uint16_t id = wi * block_size + i;
14641478 if (id < n)
1465- values[i] = std::move (
1466- exchange_acc[iter_exchange_offset +
1467- id]);
1479+ values[i] =
1480+ exchange_acc[iter_exchange_offset + id];
14681481 }
14691482
14701483 sycl::group_barrier (ndit.get_group ());
@@ -1736,10 +1749,10 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
17361749 const bool sort_ascending,
17371750 // number of sub-arrays to sort (num. of rows in a
17381751 // matrix when sorting over rows)
1739- size_t iter_nelems,
1752+ std:: size_t iter_nelems,
17401753 // size of each array to sort (length of rows,
17411754 // i.e. number of columns)
1742- size_t sort_nelems,
1755+ std:: size_t sort_nelems,
17431756 const char *arg_cp,
17441757 char *res_cp,
17451758 ssize_t iter_arg_offset,
@@ -1775,10 +1788,10 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17751788 const bool sort_ascending,
17761789 // number of sub-arrays to sort (num. of
17771790 // rows in a matrix when sorting over rows)
1778- size_t iter_nelems,
1791+ std:: size_t iter_nelems,
17791792 // size of each array to sort (length of
17801793 // rows, i.e. number of columns)
1781- size_t sort_nelems,
1794+ std:: size_t sort_nelems,
17821795 const char *arg_cp,
17831796 char *res_cp,
17841797 ssize_t iter_arg_offset,
0 commit comments