3838#include < sycl/sycl.hpp>
3939
4040#include " kernels/dpctl_tensor_types.hpp"
41+ #include " kernels/sorting/sort_utils.hpp"
4142#include " utils/sycl_alloc_utils.hpp"
4243
4344namespace dpctl
@@ -62,6 +63,47 @@ class radix_sort_reorder_peer_kernel;
6263template <std::uint32_t , bool , typename ... TrailingNames>
6364class radix_sort_reorder_kernel ;
6465
/*! @brief Smallest exponent @p e with `e >= 1` such that `n <= (1 << e)`.
 *
 *  Branchless-friendly binary search over the bits of `n - 1`; for
 *  `n >= 2` this computes `floor_log2(n - 1) + 1 == ceil_log2(n)`.
 *
 *  @param n  Unsigned 64-bit value to take the ceiling logarithm of.
 *  @return   `ceil(log2(n))`, clamped below to 1: for `n <= 1` the result
 *            is 1 (not 0) so that `1 << ceil_log2(n)` is at least 2 —
 *            callers use it to size a power-of-two scratch allocation.
 */
template <typename SizeT,
          std::enable_if_t<std::is_unsigned_v<SizeT> &&
                               sizeof(SizeT) == sizeof(std::uint64_t),
                           int> = 0>
std::uint32_t ceil_log2(SizeT n)
{
    // Clamp: mathematically ceil_log2(1) == 0, but a result of at least
    // 1 is returned so the corresponding power of two is >= 2.
    if (n <= 1)
        return std::uint32_t{1};

    std::uint32_t exp{1};
    --n;
    // Locate the highest set bit of (n - 1) by halving the search window:
    // if m >= 2^b then floor_log2(m) == b + floor_log2(m >> b).
    // Accumulating those b's into `exp` (seeded with the +1 above) yields
    // floor_log2(n - 1) + 1 == ceil_log2(n).
    if (n >= (SizeT{1} << 32)) {
        n >>= 32;
        exp += 32;
    }
    if (n >= (SizeT{1} << 16)) {
        n >>= 16;
        exp += 16;
    }
    if (n >= (SizeT{1} << 8)) {
        n >>= 8;
        exp += 8;
    }
    if (n >= (SizeT{1} << 4)) {
        n >>= 4;
        exp += 4;
    }
    if (n >= (SizeT{1} << 2)) {
        n >>= 2;
        exp += 2;
    }
    if (n >= (SizeT{1} << 1)) {
        n >>= 1;
        ++exp;
    }
    return exp;
}
106+
65107// ----------------------------------------------------------
66108// bitwise order-preserving conversions to unsigned integers
67109// ----------------------------------------------------------
@@ -1144,7 +1186,7 @@ struct subgroup_radix_sort
11441186 const std::size_t max_slm_size =
11451187 dev.template get_info <sycl::info::device::local_mem_size>() / 2 ;
11461188
1147- const auto n_uniform = 1 << ( std::uint32_t ( std::log2 (n - 1 )) + 1 );
1189+ const auto n_uniform = 1 << ceil_log2 (n );
11481190 const auto req_slm_size_val = sizeof (T) * n_uniform;
11491191
11501192 return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size)
@@ -1256,9 +1298,7 @@ struct subgroup_radix_sort
12561298 const uint16_t id = wi * block_size + i;
12571299 if (id < n)
12581300 values[i] = std::move (
1259- this_input_arr[iter_val_offset +
1260- static_cast <std::size_t >(
1261- id)]);
1301+ this_input_arr[iter_val_offset + id]);
12621302 }
12631303
12641304 while (true ) {
@@ -1272,8 +1312,7 @@ struct subgroup_radix_sort
12721312 // counting phase
12731313 auto pcounter =
12741314 get_accessor_pointer (counter_acc) +
1275- static_cast <std::size_t >(wi) +
1276- iter_counter_offset;
1315+ (wi + iter_counter_offset);
12771316
12781317// initialize counters
12791318#pragma unroll
@@ -1348,19 +1387,15 @@ struct subgroup_radix_sort
13481387
13491388 // scan contiguous numbers
13501389 uint16_t bin_sum[bin_count];
1351- bin_sum[0 ] =
1352- counter_acc[iter_counter_offset +
1353- static_cast <std::size_t >(
1354- wi * bin_count)];
1390+ const std::size_t counter_offset0 =
1391+ iter_counter_offset + wi * bin_count;
1392+ bin_sum[0 ] = counter_acc[counter_offset0];
13551393
13561394#pragma unroll
13571395 for (uint16_t i = 1 ; i < bin_count; ++i)
13581396 bin_sum[i] =
13591397 bin_sum[i - 1 ] +
1360- counter_acc
1361- [iter_counter_offset +
1362- static_cast <std::size_t >(
1363- wi * bin_count + i)];
1398+ counter_acc[counter_offset0 + i];
13641399
13651400 sycl::group_barrier (ndit.get_group ());
13661401
@@ -1374,10 +1409,7 @@ struct subgroup_radix_sort
13741409// add to local sum, generate exclusive scan result
13751410#pragma unroll
13761411 for (uint16_t i = 0 ; i < bin_count; ++i)
1377- counter_acc[iter_counter_offset +
1378- static_cast <std::size_t >(
1379- wi * bin_count + i +
1380- 1 )] =
1412+ counter_acc[counter_offset0 + i + 1 ] =
13811413 sum_scan + bin_sum[i];
13821414
13831415 if (wi == 0 )
@@ -1407,10 +1439,8 @@ struct subgroup_radix_sort
14071439 if (r < n) {
14081440 // move the values to source range and
14091441 // destroy the values
1410- this_output_arr
1411- [iter_val_offset +
1412- static_cast <std::size_t >(r)] =
1413- std::move (values[i]);
1442+ this_output_arr[iter_val_offset + r] =
1443+ std::move (values[i]);
14141444 }
14151445 }
14161446
@@ -1422,8 +1452,7 @@ struct subgroup_radix_sort
14221452 for (uint16_t i = 0 ; i < block_size; ++i) {
14231453 const uint16_t r = indices[i];
14241454 if (r < n)
1425- exchange_acc[iter_exchange_offset +
1426- static_cast <std::size_t >(r)] =
1455+ exchange_acc[iter_exchange_offset + r] =
14271456 std::move (values[i]);
14281457 }
14291458
@@ -1435,8 +1464,7 @@ struct subgroup_radix_sort
14351464 if (id < n)
14361465 values[i] = std::move (
14371466 exchange_acc[iter_exchange_offset +
1438- static_cast <std::size_t >(
1439- id)]);
1467+ id]);
14401468 }
14411469
14421470 sycl::group_barrier (ndit.get_group ());
@@ -1601,11 +1629,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16011629 using CountT = std::uint32_t ;
16021630
16031631 // memory for storing count and offset values
1604- CountT *count_ptr =
1605- sycl::malloc_device <CountT>(n_iters * n_counts, exec_q);
1606- if ( nullptr == count_ptr) {
1607- throw std::runtime_error ( " Could not allocate USM-device memory " );
1608- }
1632+ auto count_owner =
1633+ dpctl::tensor::alloc_utils::smart_malloc_device <CountT>(
1634+ n_iters * n_counts, exec_q);
1635+
1636+ CountT *count_ptr = count_owner. get ();
16091637
16101638 constexpr std::uint32_t zero_radix_iter{0 };
16111639
@@ -1618,25 +1646,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16181646 n_counts, count_ptr, proj_op,
16191647 is_ascending, depends);
16201648
1621- sort_ev = exec_q.submit ([=](sycl::handler &cgh) {
1622- cgh.depends_on (sort_ev);
1623- const sycl::context &ctx = exec_q.get_context ();
1624-
1625- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1626- cgh.host_task (
1627- [ctx, count_ptr]() { sycl_free_noexcept (count_ptr, ctx); });
1628- });
1649+ sort_ev = dpctl::tensor::alloc_utils::async_smart_free (
1650+ exec_q, {sort_ev}, count_owner);
16291651
16301652 return sort_ev;
16311653 }
16321654
1633- ValueT *tmp_arr =
1634- sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q);
1635- if (nullptr == tmp_arr) {
1636- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1637- sycl_free_noexcept (count_ptr, exec_q);
1638- throw std::runtime_error (" Could not allocate USM-device memory" );
1639- }
1655+ auto tmp_arr_owner =
1656+ dpctl::tensor::alloc_utils::smart_malloc_device<ValueT>(
1657+ n_iters * n_to_sort, exec_q);
1658+
1659+ ValueT *tmp_arr = tmp_arr_owner.get ();
16401660
16411661 // iterations per each bucket
16421662 assert (" Number of iterations must be even" && radix_iters % 2 == 0 );
@@ -1670,17 +1690,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16701690 }
16711691 }
16721692
1673- sort_ev = exec_q.submit ([=](sycl::handler &cgh) {
1674- cgh.depends_on (sort_ev);
1675-
1676- const sycl::context &ctx = exec_q.get_context ();
1677-
1678- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1679- cgh.host_task ([ctx, count_ptr, tmp_arr]() {
1680- sycl_free_noexcept (tmp_arr, ctx);
1681- sycl_free_noexcept (count_ptr, ctx);
1682- });
1683- });
1693+ sort_ev = dpctl::tensor::alloc_utils::async_smart_free (
1694+ exec_q, {sort_ev}, tmp_arr_owner, count_owner);
16841695 }
16851696
16861697 return sort_ev;
@@ -1782,57 +1793,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17821793 reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
17831794
17841795 const std::size_t total_nelems = iter_nelems * sort_nelems;
1785- const std:: size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
1786- IndexTy *workspace = sycl::malloc_device <IndexTy>(
1787- padded_total_nelems + total_nelems, exec_q);
1796+ auto workspace_owner =
1797+ dpctl::tensor::alloc_utils::smart_malloc_device <IndexTy>(total_nelems,
1798+ exec_q);
17881799
1789- if (nullptr == workspace) {
1790- throw std::runtime_error (" Could not allocate workspace on device" );
1791- }
1800+ // get raw USM pointer
1801+ IndexTy *workspace = workspace_owner.get ();
17921802
17931803 using IdentityProjT = radix_sort_details::IdentityProj;
17941804 using IndexedProjT =
17951805 radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
17961806 const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
17971807
1798- sycl::event iota_ev = exec_q.submit ([&](sycl::handler &cgh) {
1799- cgh.depends_on (depends);
1808+ using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;
18001809
1801- using KernelName = radix_argsort_iota_krn<argTy, IndexTy> ;
1810+ using dpctl::tensor::kernels::sort_utils_detail::iota_impl ;
18021811
1803- cgh.parallel_for <KernelName>(
1804- sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
1805- size_t i = id[0 ];
1806- IndexTy sort_id = static_cast <IndexTy>(i);
1807- workspace[i] = sort_id;
1808- });
1809- });
1812+ sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
1813+ exec_q, workspace, total_nelems, depends);
18101814
18111815 sycl::event radix_sort_ev =
18121816 radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
18131817 exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op,
18141818 sort_ascending, {iota_ev});
18151819
1816- sycl::event map_back_ev = exec_q.submit ([&](sycl::handler &cgh) {
1817- cgh.depends_on (radix_sort_ev);
1818-
1819- using KernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1820-
1821- cgh.parallel_for <KernelName>(
1822- sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
1823- IndexTy linear_index = res_tp[id];
1824- res_tp[id] = (linear_index % sort_nelems);
1825- });
1826- });
1827-
1828- sycl::event cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
1829- cgh.depends_on (map_back_ev);
1820+ using MapBackKernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1821+ using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
18301822
1831- const sycl::context &ctx = exec_q.get_context ();
1823+ sycl::event map_back_ev = map_back_impl<MapBackKernelName, IndexTy>(
1824+ exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev});
18321825
1833- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1834- cgh.host_task ([ctx, workspace] { sycl_free_noexcept (workspace, ctx); });
1835- });
1826+ sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free (
1827+ exec_q, {map_back_ev}, workspace_owner);
18361828
18371829 return cleanup_ev;
18381830}
0 commit comments