#include <sycl/sycl.hpp>

#include "kernels/dpctl_tensor_types.hpp"
+ #include "kernels/sorting/sort_utils.hpp"
#include "utils/sycl_alloc_utils.hpp"

namespace dpctl
@@ -62,6 +63,47 @@ class radix_sort_reorder_peer_kernel;
template <std::uint32_t, bool, typename... TrailingNames>
class radix_sort_reorder_kernel;

+ /*! @brief Computes the smallest exponent such that `n <= (1 << exponent)` */
+ template <typename SizeT,
+           std::enable_if_t<std::is_unsigned_v<SizeT> &&
+                                sizeof(SizeT) == sizeof(std::uint64_t),
+                            int> = 0>
+ std::uint32_t ceil_log2(SizeT n)
+ {
+     if (n <= 1)
+         return std::uint32_t{1};
+
+     std::uint32_t exp{1};
+     --n;
+     // locate the highest set bit of the decremented value m = n - 1:
+     // if m >= 2^b, write m = q * 2^b + r with q > 0 and 0 <= r < 2^b;
+     // then floor_log2(m) == b + floor_log2(q), so the accumulated
+     // result 1 + floor_log2(n - 1) equals ceil_log2(n) for n >= 2
+     if (n >= (SizeT{1} << 32)) {
+         n >>= 32;
+         exp += 32;
+     }
+     if (n >= (SizeT{1} << 16)) {
+         n >>= 16;
+         exp += 16;
+     }
+     if (n >= (SizeT{1} << 8)) {
+         n >>= 8;
+         exp += 8;
+     }
+     if (n >= (SizeT{1} << 4)) {
+         n >>= 4;
+         exp += 4;
+     }
+     if (n >= (SizeT{1} << 2)) {
+         n >>= 2;
+         exp += 2;
+     }
+     if (n >= (SizeT{1} << 1)) {
+         n >>= 1;
+         ++exp;
+     }
+     return exp;
+ }
+
// ----------------------------------------------------------
// bitwise order-preserving conversions to unsigned integers
// ----------------------------------------------------------
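
Note: for n >= 2 the new helper agrees with C++20 std::bit_width(n - 1). A minimal host-side sanity check of that contract (standalone sketch, not part of the patch; the reference lambda and the test values are illustrative only):

#include <bit>
#include <cassert>
#include <cstdint>

int main()
{
    // reference: smallest e with n <= (1 << e), clamped to 1 for n <= 1
    auto ref = [](std::uint64_t n) -> std::uint32_t {
        return (n <= 1) ? 1u
                        : static_cast<std::uint32_t>(std::bit_width(n - 1));
    };
    assert(ref(1) == 1);    // clamped lower bound, matches ceil_log2(1)
    assert(ref(2) == 1);    // 2 <= (1 << 1)
    assert(ref(3) == 2);    // 3 <= (1 << 2)
    assert(ref(1024) == 10);
    assert(ref(1025) == 11);
    return 0;
}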
@@ -1144,7 +1186,7 @@ struct subgroup_radix_sort
        const std::size_t max_slm_size =
            dev.template get_info<sycl::info::device::local_mem_size>() / 2;

-       const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1);
+       const auto n_uniform = 1 << ceil_log2(n);

        const auto req_slm_size_val = sizeof(T) * n_uniform;

        return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size)
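
The replaced expression computed the padded size through double-precision std::log2; the integer ceil_log2 avoids the float round-trip. In this kernel n is small (bounded by SLM), so the change is presumably about exactness and avoiding floating-point math rather than an observed failure. A standalone sketch (illustrative only) of where the two expressions can diverge for large 64-bit values:

#include <cmath>
#include <cstdint>
#include <iostream>

int main()
{
    const std::uint64_t n = std::uint64_t{1} << 54;

    // old-style expression: (n - 1) rounds up to 2^54 when converted to
    // double, so the exponent comes out as 55
    const auto float_exp =
        std::uint32_t(std::log2(static_cast<double>(n - 1))) + 1;

    // integer computation as in ceil_log2: smallest e with n <= (1 << e)
    std::uint32_t int_exp = 1;
    std::uint64_t m = n - 1;
    for (std::uint32_t b = 32; b > 0; b /= 2) {
        if (m >= (std::uint64_t{1} << b)) {
            m >>= b;
            int_exp += b;
        }
    }

    std::cout << float_exp << " vs " << int_exp << "\n"; // typically "55 vs 54"
    return 0;
}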
@@ -1256,9 +1298,7 @@ struct subgroup_radix_sort
                    const uint16_t id = wi * block_size + i;
                    if (id < n)
                        values[i] = std::move(
-                           this_input_arr[iter_val_offset +
-                                          static_cast<std::size_t>(
-                                              id)]);
+                           this_input_arr[iter_val_offset + id]);
                }

                while (true) {
@@ -1272,8 +1312,7 @@ struct subgroup_radix_sort
                // counting phase
                auto pcounter =
                    get_accessor_pointer(counter_acc) +
-                   static_cast<std::size_t>(wi) +
-                   iter_counter_offset;
+                   (wi + iter_counter_offset);

                // initialize counters
#pragma unroll
@@ -1348,19 +1387,15 @@ struct subgroup_radix_sort

                    // scan contiguous numbers
                    uint16_t bin_sum[bin_count];
-                   bin_sum[0] =
-                       counter_acc[iter_counter_offset +
-                                   static_cast<std::size_t>(
-                                       wi * bin_count)];
+                   const std::size_t counter_offset0 =
+                       iter_counter_offset + wi * bin_count;
+                   bin_sum[0] = counter_acc[counter_offset0];

#pragma unroll
                    for (uint16_t i = 1; i < bin_count; ++i)
                        bin_sum[i] =
                            bin_sum[i - 1] +
-                           counter_acc
-                               [iter_counter_offset +
-                                static_cast<std::size_t>(
-                                    wi * bin_count + i)];
+                           counter_acc[counter_offset0 + i];

                    sycl::group_barrier(ndit.get_group());
@@ -1374,10 +1409,7 @@ struct subgroup_radix_sort
                    // add to local sum, generate exclusive scan result
#pragma unroll
                    for (uint16_t i = 0; i < bin_count; ++i)
-                       counter_acc[iter_counter_offset +
-                                   static_cast<std::size_t>(
-                                       wi * bin_count + i +
-                                       1)] =
+                       counter_acc[counter_offset0 + i + 1] =
                            sum_scan + bin_sum[i];

                    if (wi == 0)
@@ -1407,10 +1439,8 @@ struct subgroup_radix_sort
                        if (r < n) {
                            // move the values to source range and
                            // destroy the values
-                           this_output_arr
-                               [iter_val_offset +
-                                static_cast<std::size_t>(r)] =
-                               std::move(values[i]);
+                           this_output_arr[iter_val_offset + r] =
+                               std::move(values[i]);
                        }
                    }
@@ -1422,8 +1452,7 @@ struct subgroup_radix_sort
                    for (uint16_t i = 0; i < block_size; ++i) {
                        const uint16_t r = indices[i];
                        if (r < n)
-                           exchange_acc[iter_exchange_offset +
-                                        static_cast<std::size_t>(r)] =
+                           exchange_acc[iter_exchange_offset + r] =
                                std::move(values[i]);
                    }
@@ -1435,8 +1464,7 @@ struct subgroup_radix_sort
                        if (id < n)
                            values[i] = std::move(
                                exchange_acc[iter_exchange_offset +
-                                            static_cast<std::size_t>(
-                                                id)]);
+                                            id]);
                    }

                    sycl::group_barrier(ndit.get_group());
@@ -1601,11 +1629,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
    using CountT = std::uint32_t;

    // memory for storing count and offset values
-   CountT *count_ptr =
-       sycl::malloc_device<CountT>(n_iters * n_counts, exec_q);
-   if (nullptr == count_ptr) {
-       throw std::runtime_error("Could not allocate USM-device memory");
-   }
+   auto count_owner =
+       dpctl::tensor::alloc_utils::smart_malloc_device<CountT>(
+           n_iters * n_counts, exec_q);
+
+   CountT *count_ptr = count_owner.get();

    constexpr std::uint32_t zero_radix_iter{0};
@@ -1618,25 +1646,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
                n_counts, count_ptr, proj_op,
                is_ascending, depends);

-       sort_ev = exec_q.submit([=](sycl::handler &cgh) {
-           cgh.depends_on(sort_ev);
-           const sycl::context &ctx = exec_q.get_context();
-
-           using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-           cgh.host_task(
-               [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); });
-       });
+       sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
+           exec_q, {sort_ev}, count_owner);

        return sort_ev;
    }

-   ValueT *tmp_arr =
-       sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q);
-   if (nullptr == tmp_arr) {
-       using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-       sycl_free_noexcept(count_ptr, exec_q);
-       throw std::runtime_error("Could not allocate USM-device memory");
-   }
+   auto tmp_arr_owner =
+       dpctl::tensor::alloc_utils::smart_malloc_device<ValueT>(
+           n_iters * n_to_sort, exec_q);
+
+   ValueT *tmp_arr = tmp_arr_owner.get();

    // iterations per each bucket
    assert("Number of iterations must be even" && radix_iters % 2 == 0);
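
For readers unfamiliar with the alloc_utils helpers used above: smart_malloc_device returns an owning handle (exposing .get()) and async_smart_free schedules the deallocation on a host_task ordered after the given events. A rough standalone sketch of that pattern, assuming a unique_ptr-based owner; the names and details are illustrative and not dpctl's actual implementation:

#include <sycl/sycl.hpp>

#include <memory>
#include <stdexcept>
#include <vector>

// Owning handle for USM device memory; frees synchronously if it is
// destroyed without being handed to the async-free helper below.
template <typename T>
auto sketch_smart_malloc_device(std::size_t count, sycl::queue &q)
{
    T *p = sycl::malloc_device<T>(count, q);
    if (!p) {
        throw std::runtime_error("USM-device allocation failed");
    }
    auto deleter = [q](T *ptr) { sycl::free(ptr, q); };
    return std::unique_ptr<T, decltype(deleter)>(p, deleter);
}

// Releases ownership into a host_task that frees the buffer only after
// the supplied events have completed.
template <typename UniqueOwner>
sycl::event sketch_async_smart_free(sycl::queue &q,
                                    const std::vector<sycl::event> &deps,
                                    UniqueOwner &owner)
{
    auto *raw = owner.release();
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        cgh.host_task([raw, q]() { sycl::free(raw, q); });
    });
}

// Usage mirroring the patch (hypothetical names):
//   auto count_owner = sketch_smart_malloc_device<std::uint32_t>(n, q);
//   std::uint32_t *count_ptr = count_owner.get();
//   ... submit kernels that read/write count_ptr ...
//   sycl::event ev = sketch_async_smart_free(q, {kernel_ev}, count_owner);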
@@ -1670,17 +1690,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
            }
        }

-       sort_ev = exec_q.submit([=](sycl::handler &cgh) {
-           cgh.depends_on(sort_ev);
-
-           const sycl::context &ctx = exec_q.get_context();
-
-           using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-           cgh.host_task([ctx, count_ptr, tmp_arr]() {
-               sycl_free_noexcept(tmp_arr, ctx);
-               sycl_free_noexcept(count_ptr, ctx);
-           });
-       });
+       sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
+           exec_q, {sort_ev}, tmp_arr_owner, count_owner);
    }

    return sort_ev;
@@ -1782,57 +1793,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
        reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;

    const std::size_t total_nelems = iter_nelems * sort_nelems;
-   const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
-   IndexTy *workspace = sycl::malloc_device<IndexTy>(
-       padded_total_nelems + total_nelems, exec_q);
+   auto workspace_owner =
+       dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>(total_nelems,
+                                                                exec_q);

-   if (nullptr == workspace) {
-       throw std::runtime_error("Could not allocate workspace on device");
-   }
+   // get raw USM pointer
+   IndexTy *workspace = workspace_owner.get();

    using IdentityProjT = radix_sort_details::IdentityProj;
    using IndexedProjT =
        radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
    const IndexedProjT proj_op{arg_tp, IdentityProjT{}};

-   sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
-       cgh.depends_on(depends);
+   using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;

-       using KernelName = radix_argsort_iota_krn<argTy, IndexTy>;
+   using dpctl::tensor::kernels::sort_utils_detail::iota_impl;

-       cgh.parallel_for<KernelName>(
-           sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
-               size_t i = id[0];
-               IndexTy sort_id = static_cast<IndexTy>(i);
-               workspace[i] = sort_id;
-           });
-   });
+   sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
+       exec_q, workspace, total_nelems, depends);

    sycl::event radix_sort_ev =
        radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
            exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op,
            sort_ascending, {iota_ev});

-   sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
-       cgh.depends_on(radix_sort_ev);
-
-       using KernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
-
-       cgh.parallel_for<KernelName>(
-           sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
-               IndexTy linear_index = res_tp[id];
-               res_tp[id] = (linear_index % sort_nelems);
-           });
-   });
-
-   sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
-       cgh.depends_on(map_back_ev);
+   using MapBackKernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
+   using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;

-       const sycl::context &ctx = exec_q.get_context();
+   sycl::event map_back_ev = map_back_impl<MapBackKernelName, IndexTy>(
+       exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev});

-       using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-       cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
-   });
+   sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+       exec_q, {map_back_ev}, workspace_owner);

    return cleanup_ev;
}
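
For orientation, the two helpers pulled in from kernels/sorting/sort_utils.hpp replace the inline kernels removed above. A rough sketch of what they likely do, reconstructed from those removed kernels and from the call sites; the actual signatures in sort_utils.hpp may differ:

#include <sycl/sycl.hpp>

#include <cstddef>
#include <vector>

// Fill data with 0, 1, ..., n - 1, as the removed inline iota kernel did.
template <typename KernelName, typename IndexTy>
sycl::event sketch_iota_impl(sycl::queue &q,
                             IndexTy *data,
                             std::size_t n,
                             const std::vector<sycl::event> &deps)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        cgh.parallel_for<KernelName>(sycl::range<1>(n), [=](sycl::id<1> id) {
            data[id] = static_cast<IndexTy>(id[0]);
        });
    });
}

// Reduce a flattened argsort index to its offset within the sort axis,
// as the removed write-out kernel did (src and dst may alias).
template <typename KernelName, typename IndexTy>
sycl::event sketch_map_back_impl(sycl::queue &q,
                                 std::size_t n,
                                 const IndexTy *src,
                                 IndexTy *dst,
                                 std::size_t row_size,
                                 const std::vector<sycl::event> &deps)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        cgh.parallel_for<KernelName>(sycl::range<1>(n), [=](sycl::id<1> id) {
            dst[id] = static_cast<IndexTy>(src[id] % row_size);
        });
    });
}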