Skip to content

Commit 087e4ea

Browse files
[stdpar][algorithms] Add support for min_element and max_element (AdaptiveCpp#1783)
* [stdpar][algorithms] Add support for `min_element` and `max_element`
* Remove debugging printfs from tests
* Align behaviour to libstdc++ when using `std::less_equal`
* Refactor comparison logic for brevity

---------

Co-authored-by: Aksel Alpay <[email protected]>
1 parent 37124f8 commit 087e4ea

File tree

9 files changed

+825
-0
lines changed

9 files changed

+825
-0
lines changed

doc/algorithms.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,34 @@ sycl::event merge(sycl::queue& q,
437437
ForwardIt3 d_first, Compare comp = std::less<>{},
438438
const std::vector<sycl::event>& deps = {});
439439

440+
template <class ForwardIt>
441+
sycl::event min_element(sycl::queue &q,
442+
util::allocation_group &scratch_allocations,
443+
ForwardIt first, ForwardIt last,
444+
std::pair<ForwardIt, typename std::iterator_traits<ForwardIt>::value_type> *out,
445+
const std::vector<sycl::event> &deps= {});
446+
447+
template <class ForwardIt, class Compare>
448+
sycl::event min_element(sycl::queue &q,
449+
util::allocation_group &scratch_allocations,
450+
ForwardIt first, ForwardIt last, Compare comp,
451+
std::pair<ForwardIt, typename std::iterator_traits<ForwardIt>::value_type> *out,
452+
const std::vector<sycl::event> &deps= {});
453+
454+
template <class ForwardIt>
455+
sycl::event max_element(sycl::queue &q,
456+
util::allocation_group &scratch_allocations,
457+
ForwardIt first, ForwardIt last,
458+
std::pair<ForwardIt, typename std::iterator_traits<ForwardIt>::value_type> *out,
459+
const std::vector<sycl::event> &deps= {});
460+
461+
template <class ForwardIt, class Compare>
462+
sycl::event max_element(sycl::queue &q,
463+
util::allocation_group &scratch_allocations,
464+
ForwardIt first, ForwardIt last, Compare comp,
465+
std::pair<ForwardIt, typename std::iterator_traits<ForwardIt>::value_type> *out,
466+
const std::vector<sycl::event> &deps= {});
467+
440468
}
441469

442470

doc/stdpar.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ Offloading is implemented for the following STL algorithms:
5959
|`equal` | |
6060
|`merge` | |
6161
|`sort` | may not scale optimally for large problems |
62+
|`min_element` | |
63+
|`max_element` | |
6264
|`is_sorted_until` | both overloads |
6365
|`is_sorted` | both overloads |
6466
|`inclusive_scan` | |

include/hipSYCL/algorithms/algorithm.hpp

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,190 @@ sycl::event count_if(sycl::queue &q, util::allocation_group &scratch_allocations
870870
deps);
871871
}
872872

873+
template <class ForwardIt>
874+
sycl::event
875+
min_element(sycl::queue &q, util::allocation_group &scratch_allocations,
876+
ForwardIt first, ForwardIt last,
877+
std::pair<ForwardIt, typename
878+
std::iterator_traits<ForwardIt>::value_type> *out,
879+
const std::vector<sycl::event> &deps= {}) {
880+
auto problem_size = std::distance(first, last);
881+
if(problem_size == 0)
882+
return sycl::event{};
883+
884+
using ValueT = typename std::iterator_traits<ForwardIt>::value_type;
885+
using MinPair = std::pair<ForwardIt, ValueT>;
886+
887+
auto kernel = [=](sycl::id<1> idx, auto& reducer) {
888+
auto input = first;
889+
std::advance(input, idx[0]);
890+
MinPair p = std::make_pair(input, *input);
891+
reducer.combine(p);
892+
};
893+
894+
auto reduce = [first] (MinPair a, MinPair b) {
895+
// Preserve strict total order over two equivalent
896+
// pointers, i.e. return the element that appears
897+
// in the sequence nearest to first.
898+
if (!(a.second < b.second) && !(b.second < a.second)) {
899+
if (std::distance(first, a.first) < std::distance(first, b.first))
900+
return a;
901+
else
902+
return b;
903+
}
904+
#if __cplusplus < 202002L
905+
else if (a.second < b.second)
906+
#else
907+
else if (std::less{}(a.second, b.second))
908+
#endif
909+
return a;
910+
else
911+
return b;
912+
};
913+
914+
MinPair init = std::make_pair(first, *first);
915+
916+
return detail::transform_reduce_impl(q, scratch_allocations, out, init,
917+
problem_size, kernel, reduce, deps);
918+
}
919+
920+
template <class ForwardIt, class Compare>
921+
sycl::event
922+
min_element(sycl::queue &q, util::allocation_group &scratch_allocations,
923+
ForwardIt first, ForwardIt last, Compare comp,
924+
std::pair<ForwardIt, typename
925+
std::iterator_traits<ForwardIt>::value_type> *out,
926+
const std::vector<sycl::event> &deps= {}) {
927+
auto problem_size = std::distance(first, last);
928+
if(problem_size == 0)
929+
return sycl::event{};
930+
931+
using ValueT = typename std::iterator_traits<ForwardIt>::value_type;
932+
using MinPair = std::pair<ForwardIt, ValueT>;
933+
934+
auto kernel = [=](sycl::id<1> idx, auto& reducer) {
935+
auto input = first;
936+
std::advance(input, idx[0]);
937+
MinPair p = std::make_pair(input, *input);
938+
reducer.combine(p);
939+
};
940+
941+
auto reduce = [comp, first] (MinPair a, MinPair b) {
942+
// Comp used for associative containers must always
943+
// return false for equal values. (Effective STL, Item 21)
944+
// In cases where it does not (for eg. std::less_equal),
945+
// implementation aligns behaviour to libstdc++, returning
946+
// the element furthest from first.
947+
if (std::distance(first, a.first) < std::distance(first, b.first))
948+
if (comp(b.second, a.second) == false)
949+
return a;
950+
else
951+
return b;
952+
else
953+
if (comp(a.second, b.second) == false)
954+
return b;
955+
else
956+
return a;
957+
};
958+
959+
MinPair init = std::make_pair(first, *first);
960+
961+
return detail::transform_reduce_impl(q, scratch_allocations, out, init,
962+
problem_size, kernel, reduce, deps);
963+
}
964+
965+
template <class ForwardIt>
966+
sycl::event
967+
max_element(sycl::queue &q, util::allocation_group &scratch_allocations,
968+
ForwardIt first, ForwardIt last,
969+
std::pair<ForwardIt, typename
970+
std::iterator_traits<ForwardIt>::value_type> *out,
971+
const std::vector<sycl::event> &deps= {}) {
972+
auto problem_size = std::distance(first, last);
973+
if(problem_size == 0)
974+
return sycl::event{};
975+
976+
using ValueT = typename std::iterator_traits<ForwardIt>::value_type;
977+
using MaxPair = std::pair<ForwardIt, ValueT>;
978+
979+
auto kernel = [=](sycl::id<1> idx, auto& reducer) {
980+
auto input = first;
981+
std::advance(input, idx[0]);
982+
MaxPair p = std::make_pair(input, *input);
983+
reducer.combine(p);
984+
};
985+
986+
auto reduce = [first] (MaxPair a, MaxPair b) {
987+
// Preserve strict total order over two equivalent
988+
// pointers, i.e. return the element that appears
989+
// in the sequence nearest to first.
990+
if (!(a.second < b.second) && !(b.second < a.second)) {
991+
if (std::distance(first, a.first) < std::distance(first, b.first))
992+
return a;
993+
else
994+
return b;
995+
}
996+
#if __cplusplus < 202002L
997+
else if (a.second < b.second)
998+
#else
999+
else if (std::less{}(a.second, b.second))
1000+
#endif
1001+
return b;
1002+
else
1003+
return a;
1004+
};
1005+
1006+
MaxPair init = std::make_pair(first, *first);
1007+
1008+
return detail::transform_reduce_impl(q, scratch_allocations, out, init,
1009+
problem_size, kernel, reduce, deps);
1010+
}
1011+
1012+
template <class ForwardIt, class Compare>
1013+
sycl::event
1014+
max_element(sycl::queue &q, util::allocation_group &scratch_allocations,
1015+
ForwardIt first, ForwardIt last, Compare comp,
1016+
std::pair<ForwardIt, typename
1017+
std::iterator_traits<ForwardIt>::value_type> *out,
1018+
const std::vector<sycl::event> &deps= {}) {
1019+
auto problem_size = std::distance(first, last);
1020+
if(problem_size == 0)
1021+
return sycl::event{};
1022+
1023+
using ValueT = typename std::iterator_traits<ForwardIt>::value_type;
1024+
using MaxPair = std::pair<ForwardIt, ValueT>;
1025+
1026+
auto kernel = [=](sycl::id<1> idx, auto& reducer) {
1027+
auto input = first;
1028+
std::advance(input, idx[0]);
1029+
MaxPair p = std::make_pair(input, *input);
1030+
reducer.combine(p);
1031+
};
1032+
1033+
auto reduce = [comp, first] (MaxPair a, MaxPair b) {
1034+
// Comp used for associative containers must always
1035+
// return false for equal values. (Effective STL, Item 21)
1036+
// In cases where it does not (for eg. std::less_equal),
1037+
// implementation aligns behaviour to libstdc++, returning
1038+
// the element furthest from first.
1039+
if (std::distance(first, a.first) < std::distance(first, b.first))
1040+
if(comp(a.second, b.second) == false)
1041+
return a;
1042+
else
1043+
return b;
1044+
else
1045+
if(comp(b.second, a.second) == false)
1046+
return b;
1047+
else
1048+
return a;
1049+
};
1050+
1051+
MaxPair init = std::make_pair(first, *first);
1052+
1053+
return detail::transform_reduce_impl(q, scratch_allocations, out, init,
1054+
problem_size, kernel, reduce, deps);
1055+
}
1056+
8731057
template <class ForwardIt1, class ForwardIt2>
8741058
sycl::event equal(sycl::queue &q,
8751059
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,

include/hipSYCL/std/stdpar/detail/algorithm_fwd.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,26 @@ typename std::iterator_traits<ForwardIt>::difference_type
191191
count_if( hipsycl::stdpar::par_unseq, ForwardIt first, ForwardIt last,
192192
UnaryPredicate p );
193193

194+
// Forward declarations of the stdpar entry points for min_element and
// max_element taking the hipsycl::stdpar::par_unseq policy tag.
// Each algorithm is declared twice: once without a comparison object and
// once with a caller-supplied `comp`. All four return a ForwardIt.
template <class ForwardIt>
195+
HIPSYCL_STDPAR_ENTRYPOINT
196+
ForwardIt min_element(hipsycl::stdpar::par_unseq, ForwardIt first,
197+
ForwardIt last);
198+
199+
template <class ForwardIt, class Compare>
200+
HIPSYCL_STDPAR_ENTRYPOINT
201+
ForwardIt min_element(hipsycl::stdpar::par_unseq, ForwardIt first,
202+
ForwardIt last, Compare comp);
203+
204+
template <class ForwardIt>
205+
HIPSYCL_STDPAR_ENTRYPOINT
206+
ForwardIt max_element(hipsycl::stdpar::par_unseq, ForwardIt first,
207+
ForwardIt last);
208+
209+
template <class ForwardIt, class Compare>
210+
HIPSYCL_STDPAR_ENTRYPOINT
211+
ForwardIt max_element(hipsycl::stdpar::par_unseq, ForwardIt first,
212+
ForwardIt last, Compare comp);
213+
194214
template<class ForwardIt>
195215
HIPSYCL_STDPAR_ENTRYPOINT
196216
bool is_sorted(hipsycl::stdpar::par_unseq, ForwardIt first, ForwardIt last);
@@ -257,6 +277,24 @@ typename std::iterator_traits<ForwardIt>::difference_type
257277
count_if( hipsycl::stdpar::par, ForwardIt first, ForwardIt last,
258278
UnaryPredicate p );
259279

280+
// Forward declarations of the stdpar entry points for min_element and
// max_element taking the hipsycl::stdpar::par policy tag.
// Each algorithm is declared twice: once without a comparison object and
// once with a caller-supplied `comp`. All four return a ForwardIt.
template <class ForwardIt>
281+
HIPSYCL_STDPAR_ENTRYPOINT
282+
ForwardIt min_element(hipsycl::stdpar::par, ForwardIt first, ForwardIt last);
283+
284+
template <class ForwardIt, class Compare>
285+
HIPSYCL_STDPAR_ENTRYPOINT
286+
ForwardIt min_element(hipsycl::stdpar::par, ForwardIt first, ForwardIt last,
287+
Compare comp);
288+
289+
template <class ForwardIt>
290+
HIPSYCL_STDPAR_ENTRYPOINT
291+
ForwardIt max_element(hipsycl::stdpar::par, ForwardIt first, ForwardIt last);
292+
293+
template <class ForwardIt, class Compare>
294+
HIPSYCL_STDPAR_ENTRYPOINT
295+
ForwardIt max_element(hipsycl::stdpar::par, ForwardIt first, ForwardIt last,
296+
Compare comp);
297+
260298
template<class ForwardIt>
261299
HIPSYCL_STDPAR_ENTRYPOINT
262300
bool is_sorted(hipsycl::stdpar::par, ForwardIt first, ForwardIt last);

include/hipSYCL/std/stdpar/detail/offload_heuristic_db.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ struct sort {};
6666
struct is_sorted {};
6767
struct is_sorted_until {};
6868
struct merge {};
69+
// Empty tag types named after the min_element and max_element algorithms.
// NOTE(review): presumably used, like the neighbouring tag structs, to
// identify the algorithm to the offload heuristic -- confirm against users.
struct min_element {};
70+
struct max_element {};
6971
struct inclusive_scan {};
7072
struct exclusive_scan {};
7173
struct transform_inclusive_scan {};

0 commit comments

Comments
 (0)