Skip to content

Commit e161ea1

Browse files
Merge pull request #1946 from IntelPython/add-iota-kernel
Add iota kernel
2 parents 8f38b80 + 4c8d7d7 commit e161ea1

File tree

6 files changed

+682
-244
lines changed

6 files changed

+682
-244
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "kernels/dpctl_tensor_types.hpp"
3535
#include "kernels/sorting/search_sorted_detail.hpp"
36+
#include "kernels/sorting/sort_utils.hpp"
3637

3738
namespace dpctl
3839
{
@@ -811,20 +812,12 @@ sycl::event stable_argsort_axis1_contig_impl(
811812

812813
const size_t total_nelems = iter_nelems * sort_nelems;
813814

814-
sycl::event populate_indexed_data_ev =
815-
exec_q.submit([&](sycl::handler &cgh) {
816-
cgh.depends_on(depends);
815+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
817816

818-
const sycl::range<1> range{total_nelems};
817+
using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
819818

820-
using KernelName =
821-
populate_index_data_krn<argTy, IndexTy, ValueComp>;
822-
823-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
824-
size_t i = id[0];
825-
res_tp[i] = static_cast<IndexTy>(i);
826-
});
827-
});
819+
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
820+
exec_q, res_tp, total_nelems, depends);
828821

829822
// Sort segments of the array
830823
sycl::event base_sort_ev =
@@ -839,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl(
839832
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
840833
{base_sort_ev});
841834

842-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
843-
cgh.depends_on(merges_ev);
844-
845-
auto temp_acc =
846-
merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
847-
cgh);
848-
849-
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
835+
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
836+
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
850837

851-
const sycl::range<1> range{total_nelems};
852-
853-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
854-
res_tp[id] = (temp_acc[id] % sort_nelems);
855-
});
856-
});
838+
sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(
839+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev});
857840

858841
return write_out_ev;
859842
}

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 83 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <sycl/sycl.hpp>
3939

4040
#include "kernels/dpctl_tensor_types.hpp"
41+
#include "kernels/sorting/sort_utils.hpp"
4142
#include "utils/sycl_alloc_utils.hpp"
4243

4344
namespace dpctl
@@ -62,6 +63,47 @@ class radix_sort_reorder_peer_kernel;
6263
template <std::uint32_t, bool, typename... TrailingNames>
6364
class radix_sort_reorder_kernel;
6465

/*! @brief Computes the smallest exponent `e` such that `n <= (1 << e)`,
 *         except that the result is clamped to be at least 1.
 *
 *  Callers use `1 << ceil_log2(n)` as a padded power-of-two size, which
 *  must be at least 2 — hence `ceil_log2(0) == ceil_log2(1) == 1`, even
 *  though `ceil(log2(1)) == 0` mathematically.
 *
 *  Restricted to 64-bit unsigned integer types.
 */
template <typename SizeT,
          std::enable_if_t<std::is_unsigned_v<SizeT> &&
                               sizeof(SizeT) == sizeof(std::uint64_t),
                           int> = 0>
std::uint32_t ceil_log2(SizeT n)
{
    // Clamp: by convention the smallest returned exponent is 1 (see @brief).
    if (n <= 1)
        return std::uint32_t{1};

    // For n >= 2: ceil_log2(n) == floor_log2(n - 1) + 1.
    // Start exp at 1 to account for the trailing "+ 1", then locate the
    // highest set bit of (n - 1) by halving the search window:
    // if n = q * 2^b + r with q > 0 and 0 <= r < 2^b, then
    // ceil_log2(q * 2^b + r) == b + ceil_log2(q)
    std::uint32_t exp{1};
    --n;
    if (n >= (SizeT{1} << 32)) {
        n >>= 32;
        exp += 32;
    }
    if (n >= (SizeT{1} << 16)) {
        n >>= 16;
        exp += 16;
    }
    if (n >= (SizeT{1} << 8)) {
        n >>= 8;
        exp += 8;
    }
    if (n >= (SizeT{1} << 4)) {
        n >>= 4;
        exp += 4;
    }
    if (n >= (SizeT{1} << 2)) {
        n >>= 2;
        exp += 2;
    }
    if (n >= (SizeT{1} << 1)) {
        n >>= 1;
        ++exp;
    }
    return exp;
}
106+
65107
//----------------------------------------------------------
66108
// bitwise order-preserving conversions to unsigned integers
67109
//----------------------------------------------------------
@@ -1144,7 +1186,7 @@ struct subgroup_radix_sort
11441186
const std::size_t max_slm_size =
11451187
dev.template get_info<sycl::info::device::local_mem_size>() / 2;
11461188

1147-
const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1);
1189+
const auto n_uniform = 1 << ceil_log2(n);
11481190
const auto req_slm_size_val = sizeof(T) * n_uniform;
11491191

11501192
return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size)
@@ -1256,9 +1298,7 @@ struct subgroup_radix_sort
12561298
const uint16_t id = wi * block_size + i;
12571299
if (id < n)
12581300
values[i] = std::move(
1259-
this_input_arr[iter_val_offset +
1260-
static_cast<std::size_t>(
1261-
id)]);
1301+
this_input_arr[iter_val_offset + id]);
12621302
}
12631303

12641304
while (true) {
@@ -1272,8 +1312,7 @@ struct subgroup_radix_sort
12721312
// counting phase
12731313
auto pcounter =
12741314
get_accessor_pointer(counter_acc) +
1275-
static_cast<std::size_t>(wi) +
1276-
iter_counter_offset;
1315+
(wi + iter_counter_offset);
12771316

12781317
// initialize counters
12791318
#pragma unroll
@@ -1348,19 +1387,15 @@ struct subgroup_radix_sort
13481387

13491388
// scan contiguous numbers
13501389
uint16_t bin_sum[bin_count];
1351-
bin_sum[0] =
1352-
counter_acc[iter_counter_offset +
1353-
static_cast<std::size_t>(
1354-
wi * bin_count)];
1390+
const std::size_t counter_offset0 =
1391+
iter_counter_offset + wi * bin_count;
1392+
bin_sum[0] = counter_acc[counter_offset0];
13551393

13561394
#pragma unroll
13571395
for (uint16_t i = 1; i < bin_count; ++i)
13581396
bin_sum[i] =
13591397
bin_sum[i - 1] +
1360-
counter_acc
1361-
[iter_counter_offset +
1362-
static_cast<std::size_t>(
1363-
wi * bin_count + i)];
1398+
counter_acc[counter_offset0 + i];
13641399

13651400
sycl::group_barrier(ndit.get_group());
13661401

@@ -1374,10 +1409,7 @@ struct subgroup_radix_sort
13741409
// add to local sum, generate exclusive scan result
13751410
#pragma unroll
13761411
for (uint16_t i = 0; i < bin_count; ++i)
1377-
counter_acc[iter_counter_offset +
1378-
static_cast<std::size_t>(
1379-
wi * bin_count + i +
1380-
1)] =
1412+
counter_acc[counter_offset0 + i + 1] =
13811413
sum_scan + bin_sum[i];
13821414

13831415
if (wi == 0)
@@ -1407,10 +1439,8 @@ struct subgroup_radix_sort
14071439
if (r < n) {
14081440
// move the values to source range and
14091441
// destroy the values
1410-
this_output_arr
1411-
[iter_val_offset +
1412-
static_cast<std::size_t>(r)] =
1413-
std::move(values[i]);
1442+
this_output_arr[iter_val_offset + r] =
1443+
std::move(values[i]);
14141444
}
14151445
}
14161446

@@ -1422,8 +1452,7 @@ struct subgroup_radix_sort
14221452
for (uint16_t i = 0; i < block_size; ++i) {
14231453
const uint16_t r = indices[i];
14241454
if (r < n)
1425-
exchange_acc[iter_exchange_offset +
1426-
static_cast<std::size_t>(r)] =
1455+
exchange_acc[iter_exchange_offset + r] =
14271456
std::move(values[i]);
14281457
}
14291458

@@ -1435,8 +1464,7 @@ struct subgroup_radix_sort
14351464
if (id < n)
14361465
values[i] = std::move(
14371466
exchange_acc[iter_exchange_offset +
1438-
static_cast<std::size_t>(
1439-
id)]);
1467+
id]);
14401468
}
14411469

14421470
sycl::group_barrier(ndit.get_group());
@@ -1601,11 +1629,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16011629
using CountT = std::uint32_t;
16021630

16031631
// memory for storing count and offset values
1604-
CountT *count_ptr =
1605-
sycl::malloc_device<CountT>(n_iters * n_counts, exec_q);
1606-
if (nullptr == count_ptr) {
1607-
throw std::runtime_error("Could not allocate USM-device memory");
1608-
}
1632+
auto count_owner =
1633+
dpctl::tensor::alloc_utils::smart_malloc_device<CountT>(
1634+
n_iters * n_counts, exec_q);
1635+
1636+
CountT *count_ptr = count_owner.get();
16091637

16101638
constexpr std::uint32_t zero_radix_iter{0};
16111639

@@ -1618,25 +1646,17 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16181646
n_counts, count_ptr, proj_op,
16191647
is_ascending, depends);
16201648

1621-
sort_ev = exec_q.submit([=](sycl::handler &cgh) {
1622-
cgh.depends_on(sort_ev);
1623-
const sycl::context &ctx = exec_q.get_context();
1624-
1625-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1626-
cgh.host_task(
1627-
[ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); });
1628-
});
1649+
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
1650+
exec_q, {sort_ev}, count_owner);
16291651

16301652
return sort_ev;
16311653
}
16321654

1633-
ValueT *tmp_arr =
1634-
sycl::malloc_device<ValueT>(n_iters * n_to_sort, exec_q);
1635-
if (nullptr == tmp_arr) {
1636-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1637-
sycl_free_noexcept(count_ptr, exec_q);
1638-
throw std::runtime_error("Could not allocate USM-device memory");
1639-
}
1655+
auto tmp_arr_owner =
1656+
dpctl::tensor::alloc_utils::smart_malloc_device<ValueT>(
1657+
n_iters * n_to_sort, exec_q);
1658+
1659+
ValueT *tmp_arr = tmp_arr_owner.get();
16401660

16411661
// iterations per each bucket
16421662
assert("Number of iterations must be even" && radix_iters % 2 == 0);
@@ -1670,17 +1690,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
16701690
}
16711691
}
16721692

1673-
sort_ev = exec_q.submit([=](sycl::handler &cgh) {
1674-
cgh.depends_on(sort_ev);
1675-
1676-
const sycl::context &ctx = exec_q.get_context();
1677-
1678-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1679-
cgh.host_task([ctx, count_ptr, tmp_arr]() {
1680-
sycl_free_noexcept(tmp_arr, ctx);
1681-
sycl_free_noexcept(count_ptr, ctx);
1682-
});
1683-
});
1693+
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
1694+
exec_q, {sort_ev}, tmp_arr_owner, count_owner);
16841695
}
16851696

16861697
return sort_ev;
@@ -1782,57 +1793,38 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17821793
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
17831794

17841795
const std::size_t total_nelems = iter_nelems * sort_nelems;
1785-
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
1786-
IndexTy *workspace = sycl::malloc_device<IndexTy>(
1787-
padded_total_nelems + total_nelems, exec_q);
1796+
auto workspace_owner =
1797+
dpctl::tensor::alloc_utils::smart_malloc_device<IndexTy>(total_nelems,
1798+
exec_q);
17881799

1789-
if (nullptr == workspace) {
1790-
throw std::runtime_error("Could not allocate workspace on device");
1791-
}
1800+
// get raw USM pointer
1801+
IndexTy *workspace = workspace_owner.get();
17921802

17931803
using IdentityProjT = radix_sort_details::IdentityProj;
17941804
using IndexedProjT =
17951805
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
17961806
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
17971807

1798-
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
1799-
cgh.depends_on(depends);
1808+
using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;
18001809

1801-
using KernelName = radix_argsort_iota_krn<argTy, IndexTy>;
1810+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
18021811

1803-
cgh.parallel_for<KernelName>(
1804-
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
1805-
size_t i = id[0];
1806-
IndexTy sort_id = static_cast<IndexTy>(i);
1807-
workspace[i] = sort_id;
1808-
});
1809-
});
1812+
sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
1813+
exec_q, workspace, total_nelems, depends);
18101814

18111815
sycl::event radix_sort_ev =
18121816
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
18131817
exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op,
18141818
sort_ascending, {iota_ev});
18151819

1816-
sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
1817-
cgh.depends_on(radix_sort_ev);
1818-
1819-
using KernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1820-
1821-
cgh.parallel_for<KernelName>(
1822-
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
1823-
IndexTy linear_index = res_tp[id];
1824-
res_tp[id] = (linear_index % sort_nelems);
1825-
});
1826-
});
1827-
1828-
sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
1829-
cgh.depends_on(map_back_ev);
1820+
using MapBackKernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
1821+
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
18301822

1831-
const sycl::context &ctx = exec_q.get_context();
1823+
sycl::event map_back_ev = map_back_impl<MapBackKernelName, IndexTy>(
1824+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev});
18321825

1833-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1834-
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
1835-
});
1826+
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
1827+
exec_q, {map_back_ev}, workspace_owner);
18361828

18371829
return cleanup_ev;
18381830
}

0 commit comments

Comments
 (0)