@@ -1414,11 +1414,11 @@ sort_pairs_impl(_ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1414
1414
// DplExtrasAlgorithm|sort_pairs
1415
1415
// DPCT_DEPENDENCY_END
1416
1416
// DPCT_CODE
1417
- template <typename _ExecutionPolicy, typename key_t , typename value_t ,
1418
- typename OffsetIteratorT>
1417
+ template <typename _ExecutionPolicy, typename key_t , typename key_out_t ,
1418
+ typename value_t , typename value_out_t , typename OffsetIteratorT>
1419
1419
inline void segmented_sort_pairs_by_parallel_sorts (
1420
- _ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in ,
1421
- value_t values_out, int64_t n, int64_t nsegments,
1420
+ _ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1421
+ value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
1422
1422
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
1423
1423
bool descending = false , int begin_bit = 0 ,
1424
1424
int end_bit = sizeof (typename ::std::iterator_traits<key_t >::value_type) *
@@ -1457,11 +1457,11 @@ inline void segmented_sort_pairs_by_parallel_sorts(
1457
1457
// DplExtrasAlgorithm|sort_pairs
1458
1458
// DPCT_DEPENDENCY_END
1459
1459
// DPCT_CODE
1460
- template <typename _ExecutionPolicy, typename key_t , typename value_t ,
1461
- typename OffsetIteratorT>
1460
+ template <typename _ExecutionPolicy, typename key_t , typename key_out_t ,
1461
+ typename value_t , typename value_out_t , typename OffsetIteratorT>
1462
1462
inline void segmented_sort_pairs_by_parallel_for_of_sorts (
1463
- _ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in ,
1464
- value_t values_out, int64_t n, int64_t nsegments,
1463
+ _ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1464
+ value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
1465
1465
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
1466
1466
bool descending = false , int begin_bit = 0 ,
1467
1467
int end_bit = sizeof (typename ::std::iterator_traits<key_t >::value_type) *
@@ -1489,11 +1489,11 @@ inline void segmented_sort_pairs_by_parallel_for_of_sorts(
1489
1489
// DplExtrasAlgorithm|sort_pairs
1490
1490
// DPCT_DEPENDENCY_END
1491
1491
// DPCT_CODE
1492
- template <typename _ExecutionPolicy, typename key_t , typename value_t ,
1493
- typename OffsetIteratorT>
1492
+ template <typename _ExecutionPolicy, typename key_t , typename key_out_t ,
1493
+ typename value_t , typename value_out_t , typename OffsetIteratorT>
1494
1494
inline void segmented_sort_pairs_by_two_pair_sorts (
1495
- _ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in ,
1496
- value_t values_out, int64_t n, int64_t nsegments,
1495
+ _ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1496
+ value_out_t values_in, value_t values_out, int64_t n, int64_t nsegments,
1497
1497
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
1498
1498
bool descending = false , int begin_bit = 0 ,
1499
1499
int end_bit = sizeof (typename ::std::iterator_traits<key_t >::value_type) *
@@ -1540,29 +1540,26 @@ inline void segmented_sort_pairs_by_two_pair_sorts(
1540
1540
// coordinate to mark segments
1541
1541
policy.queue ()
1542
1542
.submit ([&](sycl::handler &h) {
1543
- h.parallel_for (sycl::nd_range<1 >{work_group_size, work_group_size},
1544
- ([=](sycl::nd_item<1 > item) {
1545
- auto sub_group = item.get_sub_group ();
1546
- ::std::size_t num_subgroups =
1547
- sub_group.get_group_range ().size ();
1548
- ::std::size_t local_size =
1549
- sub_group.get_local_range ().size ();
1550
-
1551
- ::std::size_t sub_group_id =
1552
- sub_group.get_group_id ();
1553
- while (sub_group_id < nsegments) {
1554
- ::std::size_t subgroup_local_id =
1555
- sub_group.get_local_id ();
1556
- std::size_t i = begin_offsets[sub_group_id];
1557
- std::size_t end = end_offsets[sub_group_id];
1558
- while (i + subgroup_local_id < end) {
1559
- segments[i + subgroup_local_id] =
1560
- sub_group_id;
1561
- i += local_size;
1562
- }
1563
- sub_group_id += num_subgroups;
1564
- }
1565
- }));
1543
+ h.parallel_for (
1544
+ sycl::nd_range<1 >{work_group_size, work_group_size},
1545
+ ([=](sycl::nd_item<1 > item) {
1546
+ auto sub_group = item.get_sub_group ();
1547
+ ::std::size_t num_subgroups =
1548
+ sub_group.get_group_range ().size ();
1549
+ ::std::size_t local_size = sub_group.get_local_range ().size ();
1550
+
1551
+ ::std::size_t sub_group_id = sub_group.get_group_id ();
1552
+ while (sub_group_id < nsegments) {
1553
+ ::std::size_t subgroup_local_id = sub_group.get_local_id ();
1554
+ std::size_t i = begin_offsets[sub_group_id];
1555
+ std::size_t end = end_offsets[sub_group_id];
1556
+ while (i + subgroup_local_id < end) {
1557
+ segments[i + subgroup_local_id] = sub_group_id;
1558
+ i += local_size;
1559
+ }
1560
+ sub_group_id += num_subgroups;
1561
+ }
1562
+ }));
1566
1563
})
1567
1564
.wait ();
1568
1565
} else {
@@ -1716,13 +1713,18 @@ inline void sort_keys(
1716
1713
// DplExtrasAlgorithm|segmented_sort_pairs_by_parallel_sorts
1717
1714
// DplExtrasAlgorithm|segmented_sort_pairs_by_parallel_for_of_sorts
1718
1715
// DplExtrasAlgorithm|segmented_sort_pairs_by_two_pair_sorts
1716
+ // DplExtrasVector|is_iterator
1719
1717
// DPCT_DEPENDENCY_END
1720
1718
// DPCT_CODE
1721
- template <typename _ExecutionPolicy, typename key_t , typename value_t ,
1722
- typename OffsetIteratorT>
1723
- inline void segmented_sort_pairs (
1724
- _ExecutionPolicy &&policy, key_t keys_in, key_t keys_out, value_t values_in,
1725
- value_t values_out, int64_t n, int64_t nsegments,
1719
+ template <typename _ExecutionPolicy, typename key_t , typename key_out_t ,
1720
+ typename value_t , typename value_out_t , typename OffsetIteratorT>
1721
+ inline ::std::enable_if_t <dpct::internal::is_iterator<key_t >::value &&
1722
+ dpct::internal::is_iterator<key_out_t >::value &&
1723
+ dpct::internal::is_iterator<value_t >::value &&
1724
+ dpct::internal::is_iterator<value_out_t >::value>
1725
+ segmented_sort_pairs (
1726
+ _ExecutionPolicy &&policy, key_t keys_in, key_out_t keys_out,
1727
+ value_t values_in, value_out_t values_out, int64_t n, int64_t nsegments,
1726
1728
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
1727
1729
bool descending = false , int begin_bit = 0 ,
1728
1730
int end_bit = sizeof (typename ::std::iterator_traits<key_t >::value_type) *
0 commit comments