Skip to content

Commit 275827a

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Add sort_utils.hpp with iota_impl
Reuse that function call in sorting code-base where argsort is used.
1 parent 65f14be commit 275827a

File tree

4 files changed

+162
-42
lines changed

4 files changed

+162
-42
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "kernels/dpctl_tensor_types.hpp"
3535
#include "kernels/sorting/search_sorted_detail.hpp"
36+
#include "kernels/sorting/sort_utils.hpp"
3637

3738
namespace dpctl
3839
{
@@ -811,20 +812,27 @@ sycl::event stable_argsort_axis1_contig_impl(
811812

812813
const size_t total_nelems = iter_nelems * sort_nelems;
813814

815+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
816+
817+
using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
818+
819+
#if 1
820+
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
821+
exec_q, res_tp, total_nelems, depends);
822+
823+
#else
814824
sycl::event populate_indexed_data_ev =
815825
exec_q.submit([&](sycl::handler &cgh) {
816826
cgh.depends_on(depends);
817827
818828
const sycl::range<1> range{total_nelems};
819829
820-
using KernelName =
821-
populate_index_data_krn<argTy, IndexTy, ValueComp>;
822-
823-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
830+
cgh.parallel_for<IotaKernelName>(range, [=](sycl::id<1> id) {
824831
size_t i = id[0];
825832
res_tp[i] = static_cast<IndexTy>(i);
826833
});
827834
});
835+
#endif
828836

829837
// Sort segments of the array
830838
sycl::event base_sort_ev =

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <sycl/sycl.hpp>
3939

4040
#include "kernels/dpctl_tensor_types.hpp"
41+
#include "kernels/sorting/sort_utils.hpp"
4142
#include "utils/sycl_alloc_utils.hpp"
4243

4344
namespace dpctl
@@ -1256,9 +1257,7 @@ struct subgroup_radix_sort
12561257
const uint16_t id = wi * block_size + i;
12571258
if (id < n)
12581259
values[i] = std::move(
1259-
this_input_arr[iter_val_offset +
1260-
static_cast<std::size_t>(
1261-
id)]);
1260+
this_input_arr[iter_val_offset + id]);
12621261
}
12631262

12641263
while (true) {
@@ -1272,8 +1271,7 @@ struct subgroup_radix_sort
12721271
// counting phase
12731272
auto pcounter =
12741273
get_accessor_pointer(counter_acc) +
1275-
static_cast<std::size_t>(wi) +
1276-
iter_counter_offset;
1274+
(wi + iter_counter_offset);
12771275

12781276
// initialize counters
12791277
#pragma unroll
@@ -1348,19 +1346,15 @@ struct subgroup_radix_sort
13481346

13491347
// scan contiguous numbers
13501348
uint16_t bin_sum[bin_count];
1351-
bin_sum[0] =
1352-
counter_acc[iter_counter_offset +
1353-
static_cast<std::size_t>(
1354-
wi * bin_count)];
1349+
const std::size_t counter_offset0 =
1350+
iter_counter_offset + wi * bin_count;
1351+
bin_sum[0] = counter_acc[counter_offset0];
13551352

13561353
#pragma unroll
13571354
for (uint16_t i = 1; i < bin_count; ++i)
13581355
bin_sum[i] =
13591356
bin_sum[i - 1] +
1360-
counter_acc
1361-
[iter_counter_offset +
1362-
static_cast<std::size_t>(
1363-
wi * bin_count + i)];
1357+
counter_acc[counter_offset0 + i];
13641358

13651359
sycl::group_barrier(ndit.get_group());
13661360

@@ -1374,10 +1368,7 @@ struct subgroup_radix_sort
13741368
// add to local sum, generate exclusive scan result
13751369
#pragma unroll
13761370
for (uint16_t i = 0; i < bin_count; ++i)
1377-
counter_acc[iter_counter_offset +
1378-
static_cast<std::size_t>(
1379-
wi * bin_count + i +
1380-
1)] =
1371+
counter_acc[counter_offset0 + i + 1] =
13811372
sum_scan + bin_sum[i];
13821373

13831374
if (wi == 0)
@@ -1407,10 +1398,8 @@ struct subgroup_radix_sort
14071398
if (r < n) {
14081399
// move the values to source range and
14091400
// destroy the values
1410-
this_output_arr
1411-
[iter_val_offset +
1412-
static_cast<std::size_t>(r)] =
1413-
std::move(values[i]);
1401+
this_output_arr[iter_val_offset + r] =
1402+
std::move(values[i]);
14141403
}
14151404
}
14161405

@@ -1422,8 +1411,7 @@ struct subgroup_radix_sort
14221411
for (uint16_t i = 0; i < block_size; ++i) {
14231412
const uint16_t r = indices[i];
14241413
if (r < n)
1425-
exchange_acc[iter_exchange_offset +
1426-
static_cast<std::size_t>(r)] =
1414+
exchange_acc[iter_exchange_offset + r] =
14271415
std::move(values[i]);
14281416
}
14291417

@@ -1435,8 +1423,7 @@ struct subgroup_radix_sort
14351423
if (id < n)
14361424
values[i] = std::move(
14371425
exchange_acc[iter_exchange_offset +
1438-
static_cast<std::size_t>(
1439-
id)]);
1426+
id]);
14401427
}
14411428

14421429
sycl::group_barrier(ndit.get_group());
@@ -1795,18 +1782,26 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17951782
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
17961783
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
17971784

1785+
using IotaKernelName = radix_argsort_iota_krn<argTy, IndexTy>;
1786+
1787+
#if 1
1788+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
1789+
1790+
sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
1791+
exec_q, workspace, total_nelems, depends);
1792+
#else
1793+
17981794
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
17991795
cgh.depends_on(depends);
18001796
1801-
using KernelName = radix_argsort_iota_krn<argTy, IndexTy>;
1802-
1803-
cgh.parallel_for<KernelName>(
1797+
cgh.parallel_for<IotaKernelName>(
18041798
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
18051799
size_t i = id[0];
18061800
IndexTy sort_id = static_cast<IndexTy>(i);
18071801
workspace[i] = sort_id;
18081802
});
18091803
});
1804+
#endif
18101805

18111806
sycl::event radix_sort_ev =
18121807
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//=== sorting.hpp - Implementation of sorting kernels ---*-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2024 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for tensor sort/argsort operations.
23+
//===----------------------------------------------------------------------===//
24+
25+
#pragma once
26+
27+
#include <cstddef>
28+
#include <cstdint>
29+
#include <vector>
30+
31+
#include <sycl/sycl.hpp>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace kernels
38+
{
39+
namespace sort_utils_detail
40+
{
41+
42+
namespace syclexp = sycl::ext::oneapi::experimental;
43+
44+
template <class KernelName, typename T>
45+
sycl::event iota_impl(sycl::queue &exec_q,
46+
T *data,
47+
std::size_t nelems,
48+
const std::vector<sycl::event> &dependent_events)
49+
{
50+
constexpr std::uint32_t lws = 256;
51+
constexpr std::uint32_t n_wi = 4;
52+
const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws);
53+
54+
sycl::range<1> gRange{n_groups * lws};
55+
sycl::range<1> lRange{lws};
56+
sycl::nd_range<1> ndRange{gRange, lRange};
57+
58+
sycl::event e = exec_q.submit([&](sycl::handler &cgh) {
59+
cgh.depends_on(dependent_events);
60+
cgh.parallel_for<KernelName>(ndRange, [=](sycl::nd_item<1> it) {
61+
const std::size_t gid = it.get_global_id();
62+
const auto &sg = it.get_sub_group();
63+
const std::uint32_t lane_id = sg.get_local_id()[0];
64+
65+
const std::size_t offset = (gid - lane_id) * n_wi;
66+
const std::uint32_t max_sgSize = sg.get_max_local_range()[0];
67+
68+
std::array<T, n_wi> stripe{};
69+
#pragma unroll
70+
for (std::uint32_t i = 0; i < n_wi; ++i) {
71+
stripe[i] = T(offset + lane_id + i * max_sgSize);
72+
}
73+
74+
if (offset + n_wi * max_sgSize < nelems) {
75+
constexpr auto group_ls_props = syclexp::properties{
76+
syclexp::data_placement_striped
77+
// , syclexp::full_group
78+
};
79+
80+
auto out_multi_ptr = sycl::address_space_cast<
81+
sycl::access::address_space::global_space,
82+
sycl::access::decorated::yes>(&data[offset]);
83+
84+
syclexp::group_store(sg, sycl::span<T, n_wi>{&stripe[0], n_wi},
85+
out_multi_ptr, group_ls_props);
86+
}
87+
else {
88+
for (std::size_t idx = offset + lane_id; idx < nelems;
89+
idx += max_sgSize)
90+
{
91+
data[idx] = T(idx);
92+
}
93+
}
94+
});
95+
});
96+
97+
return e;
98+
}
99+
100+
} // end of namespace sort_utils_detail
101+
} // end of namespace kernels
102+
} // end of namespace tensor
103+
} // end of namespace dpctl

dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434
#include <vector>
3535

3636
#include "kernels/dpctl_tensor_types.hpp"
37-
#include "merge_sort.hpp"
38-
#include "radix_sort.hpp"
39-
#include "search_sorted_detail.hpp"
37+
#include "kernels/sorting/merge_sort.hpp"
38+
#include "kernels/sorting/radix_sort.hpp"
39+
#include "kernels/sorting/search_sorted_detail.hpp"
40+
#include "kernels/sorting/sort_utils.hpp"
4041
#include "utils/sycl_alloc_utils.hpp"
4142
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4243

@@ -95,20 +96,26 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
9596
throw std::runtime_error("Unable to allocate device_memory");
9697
}
9798

99+
using IotaKernelName = topk_populate_index_data_krn<argTy, IndexTy, CompT>;
100+
101+
#if 1
102+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
103+
104+
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
105+
exec_q, index_data, iter_nelems * axis_nelems, depends);
106+
#else
98107
sycl::event populate_indexed_data_ev =
99108
exec_q.submit([&](sycl::handler &cgh) {
100109
cgh.depends_on(depends);
101110
102111
auto const &range = sycl::range<1>(iter_nelems * axis_nelems);
103112
104-
using KernelName =
105-
topk_populate_index_data_krn<argTy, IndexTy, CompT>;
106-
107-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
113+
cgh.parallel_for<IotaKernelName>(range, [=](sycl::id<1> id) {
108114
std::size_t i = id[0];
109115
index_data[i] = static_cast<IndexTy>(i);
110116
});
111117
});
118+
#endif
112119

113120
std::size_t sorted_block_size;
114121
// Sort segments of the array
@@ -480,18 +487,25 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
480487
radix_sort_details::IndexedProj<IndexTy, argTy, IdentityProjT>;
481488
const IndexedProjT proj_op{arg_tp, IdentityProjT{}};
482489

490+
using IotaKernelName = topk_iota_krn<argTy, IndexTy>;
491+
492+
#if 1
493+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
494+
495+
sycl::event iota_ev = iota_impl<IotaKernelName, IndexTy>(
496+
exec_q, workspace, total_nelems, depends);
497+
#else
483498
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
484499
cgh.depends_on(depends);
485500
486-
using KernelName = topk_iota_krn<argTy, IndexTy>;
487-
488-
cgh.parallel_for<KernelName>(
501+
cgh.parallel_for<IotaKernelName>(
489502
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
490503
size_t i = id[0];
491504
IndexTy sort_id = static_cast<IndexTy>(i);
492505
workspace[i] = sort_id;
493506
});
494507
});
508+
#endif
495509

496510
sycl::event radix_sort_ev =
497511
radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT>(

0 commit comments

Comments
 (0)