Skip to content

Commit afca57c

Browse files
[SYCLomatic] Fix for in/out types segmented sort (#483)
Signed-off-by: Dan Hoeflinger <[email protected]>
1 parent dbba3c0 commit afca57c

File tree

2 files changed

+83
-80
lines changed

2 files changed

+83
-80
lines changed

clang/runtime/dpct-rt/include/dpl_extras/algorithm.h.inc

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,11 +1414,11 @@ sort_pairs_impl(_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
14141414
// DplExtrasAlgorithm|sort_pairs
14151415
// DPCT_DEPENDENCY_END
14161416
// DPCT_CODE
1417-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1418-
typename OffsetIteratorT>
1417+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1418+
typename value_t, typename value_out_t, typename OffsetIteratorT>
14191419
inline void segmented_sort_pairs_by_parallel_sorts(
1420-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1421-
value_t values_out, int64_t n, int64_t nsegments,
1420+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1421+
value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
14221422
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
14231423
bool descending = false, int begin_bit = 0,
14241424
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1457,11 +1457,11 @@ inline void segmented_sort_pairs_by_parallel_sorts(
14571457
// DplExtrasAlgorithm|sort_pairs
14581458
// DPCT_DEPENDENCY_END
14591459
// DPCT_CODE
1460-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1461-
typename OffsetIteratorT>
1460+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1461+
typename value_t, typename value_out_t, typename OffsetIteratorT>
14621462
inline void segmented_sort_pairs_by_parallel_for_of_sorts(
1463-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1464-
value_t values_out, int64_t n, int64_t nsegments,
1463+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1464+
value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
14651465
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
14661466
bool descending = false, int begin_bit = 0,
14671467
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1489,11 +1489,11 @@ inline void segmented_sort_pairs_by_parallel_for_of_sorts(
14891489
// DplExtrasAlgorithm|sort_pairs
14901490
// DPCT_DEPENDENCY_END
14911491
// DPCT_CODE
1492-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1493-
typename OffsetIteratorT>
1492+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1493+
typename value_t, typename value_out_t, typename OffsetIteratorT>
14941494
inline void segmented_sort_pairs_by_two_pair_sorts(
1495-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1496-
value_t values_out, int64_t n, int64_t nsegments,
1495+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1496+
value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
14971497
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
14981498
bool descending = false, int begin_bit = 0,
14991499
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1540,29 +1540,26 @@ inline void segmented_sort_pairs_by_two_pair_sorts(
15401540
// coordinate to mark segments
15411541
policy.queue()
15421542
.submit([&](sycl::handler &h) {
1543-
h.parallel_for(sycl::nd_range<1>{work_group_size, work_group_size},
1544-
([=](sycl::nd_item<1> item) {
1545-
auto sub_group = item.get_sub_group();
1546-
::std::size_t num_subgroups =
1547-
sub_group.get_group_range().size();
1548-
::std::size_t local_size =
1549-
sub_group.get_local_range().size();
1550-
1551-
::std::size_t sub_group_id =
1552-
sub_group.get_group_id();
1553-
while (sub_group_id < nsegments) {
1554-
::std::size_t subgroup_local_id =
1555-
sub_group.get_local_id();
1556-
std::size_t i = begin_offsets[sub_group_id];
1557-
std::size_t end = end_offsets[sub_group_id];
1558-
while (i + subgroup_local_id < end) {
1559-
segments[i + subgroup_local_id] =
1560-
sub_group_id;
1561-
i += local_size;
1562-
}
1563-
sub_group_id += num_subgroups;
1564-
}
1565-
}));
1543+
h.parallel_for(
1544+
sycl::nd_range<1>{work_group_size, work_group_size},
1545+
([=](sycl::nd_item<1> item) {
1546+
auto sub_group = item.get_sub_group();
1547+
::std::size_t num_subgroups =
1548+
sub_group.get_group_range().size();
1549+
::std::size_t local_size = sub_group.get_local_range().size();
1550+
1551+
::std::size_t sub_group_id = sub_group.get_group_id();
1552+
while (sub_group_id < nsegments) {
1553+
::std::size_t subgroup_local_id = sub_group.get_local_id();
1554+
std::size_t i = begin_offsets[sub_group_id];
1555+
std::size_t end = end_offsets[sub_group_id];
1556+
while (i + subgroup_local_id < end) {
1557+
segments[i + subgroup_local_id] = sub_group_id;
1558+
i += local_size;
1559+
}
1560+
sub_group_id += num_subgroups;
1561+
}
1562+
}));
15661563
})
15671564
.wait();
15681565
} else {
@@ -1716,13 +1713,18 @@ inline void sort_keys(
17161713
// DplExtrasAlgorithm|segmented_sort_pairs_by_parallel_sorts
17171714
// DplExtrasAlgorithm|segmented_sort_pairs_by_parallel_for_of_sorts
17181715
// DplExtrasAlgorithm|segmented_sort_pairs_by_two_pair_sorts
1716+
// DplExtrasVector|is_iterator
17191717
// DPCT_DEPENDENCY_END
17201718
// DPCT_CODE
1721-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1722-
typename OffsetIteratorT>
1723-
inline void segmented_sort_pairs(
1724-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1725-
value_t values_out, int64_t n, int64_t nsegments,
1719+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1720+
typename value_t, typename value_out_t, typename OffsetIteratorT>
1721+
inline ::std::enable_if_t<dpct::internal::is_iterator<key_t>::value &&
1722+
dpct::internal::is_iterator<key_out_t>::value &&
1723+
dpct::internal::is_iterator<value_t>::value &&
1724+
dpct::internal::is_iterator<value_out_t>::value>
1725+
segmented_sort_pairs(
1726+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1727+
value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
17261728
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
17271729
bool descending = false, int begin_bit = 0,
17281730
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *

clang/test/dpct/helper_files_ref/include/dpl_extras/algorithm.h

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,11 +1180,11 @@ sort_pairs_impl(_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
11801180
sycl::free(temp_keys_out, policy.queue());
11811181
}
11821182

1183-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1184-
typename OffsetIteratorT>
1183+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1184+
typename value_t, typename value_out_t, typename OffsetIteratorT>
11851185
inline void segmented_sort_pairs_by_parallel_sorts(
1186-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1187-
value_t values_out, int64_t n, int64_t nsegments,
1186+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1187+
value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
11881188
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
11891189
bool descending = false, int begin_bit = 0,
11901190
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1217,11 +1217,11 @@ inline void segmented_sort_pairs_by_parallel_sorts(
12171217
sycl::free(host_accessible_offset_ends, policy.queue());
12181218
}
12191219

1220-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1221-
typename OffsetIteratorT>
1220+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1221+
typename value_t, typename value_out_t, typename OffsetIteratorT>
12221222
inline void segmented_sort_pairs_by_parallel_for_of_sorts(
1223-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1224-
value_t values_out, int64_t n, int64_t nsegments,
1223+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1224+
value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
12251225
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
12261226
bool descending = false, int begin_bit = 0,
12271227
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1243,11 +1243,11 @@ inline void segmented_sort_pairs_by_parallel_for_of_sorts(
12431243
policy.queue().wait();
12441244
}
12451245

1246-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1247-
typename OffsetIteratorT>
1246+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1247+
typename value_t, typename value_out_t, typename OffsetIteratorT>
12481248
inline void segmented_sort_pairs_by_two_pair_sorts(
1249-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1250-
value_t values_out, int64_t n, int64_t nsegments,
1249+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1250+
value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
12511251
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
12521252
bool descending = false, int begin_bit = 0,
12531253
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *
@@ -1294,29 +1294,26 @@ inline void segmented_sort_pairs_by_two_pair_sorts(
12941294
// coordinate to mark segments
12951295
policy.queue()
12961296
.submit([&](sycl::handler &h) {
1297-
h.parallel_for(sycl::nd_range<1>{work_group_size, work_group_size},
1298-
([=](sycl::nd_item<1> item) {
1299-
auto sub_group = item.get_sub_group();
1300-
::std::size_t num_subgroups =
1301-
sub_group.get_group_range().size();
1302-
::std::size_t local_size =
1303-
sub_group.get_local_range().size();
1304-
1305-
::std::size_t sub_group_id =
1306-
sub_group.get_group_id();
1307-
while (sub_group_id < nsegments) {
1308-
::std::size_t subgroup_local_id =
1309-
sub_group.get_local_id();
1310-
std::size_t i = begin_offsets[sub_group_id];
1311-
std::size_t end = end_offsets[sub_group_id];
1312-
while (i + subgroup_local_id < end) {
1313-
segments[i + subgroup_local_id] =
1314-
sub_group_id;
1315-
i += local_size;
1316-
}
1317-
sub_group_id += num_subgroups;
1318-
}
1319-
}));
1297+
h.parallel_for(
1298+
sycl::nd_range<1>{work_group_size, work_group_size},
1299+
([=](sycl::nd_item<1> item) {
1300+
auto sub_group = item.get_sub_group();
1301+
::std::size_t num_subgroups =
1302+
sub_group.get_group_range().size();
1303+
::std::size_t local_size = sub_group.get_local_range().size();
1304+
1305+
::std::size_t sub_group_id = sub_group.get_group_id();
1306+
while (sub_group_id < nsegments) {
1307+
::std::size_t subgroup_local_id = sub_group.get_local_id();
1308+
std::size_t i = begin_offsets[sub_group_id];
1309+
std::size_t end = end_offsets[sub_group_id];
1310+
while (i + subgroup_local_id < end) {
1311+
segments[i + subgroup_local_id] = sub_group_id;
1312+
i += local_size;
1313+
}
1314+
sub_group_id += num_subgroups;
1315+
}
1316+
}));
13201317
})
13211318
.wait();
13221319
} else {
@@ -1434,11 +1431,15 @@ inline void sort_keys(
14341431
keys.swap();
14351432
}
14361433

1437-
template <typename _ExecutionPolicy, typename key_t, typename value_t,
1438-
typename OffsetIteratorT>
1439-
inline void segmented_sort_pairs(
1440-
_ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1441-
value_t values_out, int64_t n, int64_t nsegments,
1434+
template <typename _ExecutionPolicy, typename key_t, typename key_out_t,
1435+
typename value_t, typename value_out_t, typename OffsetIteratorT>
1436+
inline ::std::enable_if_t<dpct::internal::is_iterator<key_t>::value &&
1437+
dpct::internal::is_iterator<key_out_t>::value &&
1438+
dpct::internal::is_iterator<value_t>::value &&
1439+
dpct::internal::is_iterator<value_out_t>::value>
1440+
segmented_sort_pairs(
1441+
_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1442+
value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
14421443
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
14431444
bool descending = false, int begin_bit = 0,
14441445
int end_bit = sizeof(typename ::std::iterator_traits<key_t>::value_type) *

0 commit comments

Comments
 (0)