Skip to content

Commit 7dd53c1

Browse files
[stdpar][algorithms] Add support for mismatch (AdaptiveCpp#1824)
Co-authored-by: Aksel Alpay <[email protected]>
1 parent 087e4ea commit 7dd53c1

File tree

8 files changed

+648
-0
lines changed

8 files changed

+648
-0
lines changed

doc/algorithms.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,37 @@ sycl::event count_if(sycl::queue &q, util::allocation_group &scratch_allocations
352352
typename std::iterator_traits<ForwardIt>::difference_type *out,
353353
UnaryPredicate p, const std::vector<sycl::event> &deps = {});
354354

355+
/// out must point to memory that is accessible on the target device.
356+
template <class ForwardIt1, class ForwardIt2>
357+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
358+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
359+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
360+
const std::vector<sycl::event>& deps = {});
361+
362+
/// out must point to memory that is accessible on the target device.
363+
template <class ForwardIt1, class ForwardIt2, class BinaryPredicate>
364+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
365+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
366+
BinaryPredicate p,
367+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
368+
const std::vector<sycl::event>& deps = {});
369+
370+
/// out must point to memory that is accessible on the target device.
371+
template <class ForwardIt1, class ForwardIt2>
372+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
373+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
374+
ForwardIt2 last2,
375+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
376+
const std::vector<sycl::event>& deps = {});
377+
378+
/// out must point to memory that is accessible on the target device.
379+
template <class ForwardIt1, class ForwardIt2, class BinaryPredicate>
380+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
381+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
382+
ForwardIt2 last2, BinaryPredicate p,
383+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
384+
const std::vector<sycl::event>& deps = {});
385+
355386
/// The result of the operation will be stored in out.
356387
///
357388
/// out must point to device-accessible memory, and will be set to 0

doc/stdpar.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Offloading is implemented for the following STL algorithms:
5656
|`none_of` | |
5757
|`count` | |
5858
|`count_if` | |
59+
|`mismatch` | |
5960
|`equal` | |
6061
|`merge` | |
6162
|`sort` | may not scale optimally for large problems |

include/hipSYCL/algorithms/algorithm.hpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,73 @@ sycl::event count_if(sycl::queue &q, util::allocation_group &scratch_allocations
870870
deps);
871871
}
872872

873+
template <class ForwardIt1, class ForwardIt2, class BinaryPredicate>
874+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
875+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
876+
ForwardIt2 last2, BinaryPredicate p,
877+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
878+
const std::vector<sycl::event>& deps = {}) {
879+
if (first1 == last1 || first2 == last2)
880+
return sycl::event{};
881+
882+
using DiffT = typename std::iterator_traits<ForwardIt1>::difference_type;
883+
DiffT problem_size = std::min(std::distance(first1, last1),
884+
std::distance(first2, last2));
885+
886+
auto kernel = [=](sycl::id<1> idx, auto& reducer) {
887+
auto input1 = std::next(first1, idx[0]);
888+
auto input2 = std::next(first2, idx[0]);
889+
if ( p(*input1, *input2) )
890+
reducer.combine(problem_size);
891+
else
892+
reducer.combine(idx[0]);
893+
};
894+
895+
auto reduce = sycl::minimum<DiffT>{};
896+
897+
return detail::transform_reduce_impl(q, scratch_allocations, out,
898+
std::numeric_limits<DiffT>::max(),
899+
problem_size, kernel, reduce, deps);
900+
}
901+
902+
template <class ForwardIt1, class ForwardIt2, class BinaryPredicate>
903+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
904+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
905+
BinaryPredicate p,
906+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
907+
const std::vector<sycl::event>& deps = {}) {
908+
if (first1 == last1)
909+
return sycl::event{};
910+
911+
return mismatch(q, scratch_allocations, first1, last1, first2,
912+
std::next(first2, std::distance(first1, last1)), p, out, deps);
913+
}
914+
915+
template <class ForwardIt1, class ForwardIt2>
916+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
917+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
918+
ForwardIt2 last2,
919+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
920+
const std::vector<sycl::event>& deps = {}) {
921+
if (first1 == last1 || first2 == last2)
922+
return sycl::event{};
923+
924+
return mismatch(q, scratch_allocations, first1, last1, first2,
925+
last2, std::equal_to<>(), out, deps);
926+
}
927+
928+
template <class ForwardIt1, class ForwardIt2>
929+
sycl::event mismatch(sycl::queue &q, util::allocation_group &scratch_allocations,
930+
ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2,
931+
typename std::iterator_traits<ForwardIt1>::difference_type* out,
932+
const std::vector<sycl::event>& deps = {}) {
933+
if (first1 == last1)
934+
return sycl::event{};
935+
936+
return mismatch(q, scratch_allocations, first1, last1, first2,
937+
std::next(first2, std::distance(first1, last1)), out, deps);
938+
}
939+
873940
template <class ForwardIt>
874941
sycl::event
875942
min_element(sycl::queue &q, util::allocation_group &scratch_allocations,

include/hipSYCL/std/stdpar/detail/algorithm_fwd.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,31 @@ typename std::iterator_traits<ForwardIt>::difference_type
191191
count_if( hipsycl::stdpar::par_unseq, ForwardIt first, ForwardIt last,
192192
UnaryPredicate p );
193193

194+
template<class ForwardIt1, class ForwardIt2>
195+
HIPSYCL_STDPAR_ENTRYPOINT
196+
std::pair<ForwardIt1, ForwardIt2> mismatch( hipsycl::stdpar::par_unseq,
197+
ForwardIt1 first1, ForwardIt1 last1,
198+
ForwardIt2 first2 );
199+
200+
template<class ForwardIt1, class ForwardIt2, class BinaryPredicate>
201+
HIPSYCL_STDPAR_ENTRYPOINT
202+
std::pair<ForwardIt1, ForwardIt2> mismatch(hipsycl::stdpar::par_unseq,
203+
ForwardIt1 first1, ForwardIt1 last1,
204+
ForwardIt2 first2, BinaryPredicate p);
205+
206+
template<class ForwardIt1, class ForwardIt2>
207+
HIPSYCL_STDPAR_ENTRYPOINT
208+
std::pair<ForwardIt1, ForwardIt2> mismatch( hipsycl::stdpar::par_unseq,
209+
ForwardIt1 first1, ForwardIt1 last1,
210+
ForwardIt2 first2, ForwardIt2 last2 );
211+
212+
template<class ForwardIt1, class ForwardIt2, class BinaryPredicate>
213+
HIPSYCL_STDPAR_ENTRYPOINT
214+
std::pair<ForwardIt1, ForwardIt2> mismatch(hipsycl::stdpar::par_unseq,
215+
ForwardIt1 first1, ForwardIt1 last1,
216+
ForwardIt2 first2, ForwardIt2 last2,
217+
BinaryPredicate p);
218+
194219
template <class ForwardIt>
195220
HIPSYCL_STDPAR_ENTRYPOINT
196221
ForwardIt min_element(hipsycl::stdpar::par_unseq, ForwardIt first,
@@ -277,6 +302,31 @@ typename std::iterator_traits<ForwardIt>::difference_type
277302
count_if( hipsycl::stdpar::par, ForwardIt first, ForwardIt last,
278303
UnaryPredicate p );
279304

305+
template<class ForwardIt1, class ForwardIt2>
306+
HIPSYCL_STDPAR_ENTRYPOINT
307+
std::pair<ForwardIt1, ForwardIt2> mismatch( hipsycl::stdpar::par,
308+
ForwardIt1 first1, ForwardIt1 last1,
309+
ForwardIt2 first2 );
310+
311+
template<class ForwardIt1, class ForwardIt2, class BinaryPredicate>
312+
HIPSYCL_STDPAR_ENTRYPOINT
313+
std::pair<ForwardIt1, ForwardIt2> mismatch(hipsycl::stdpar::par,
314+
ForwardIt1 first1, ForwardIt1 last1,
315+
ForwardIt2 first2, BinaryPredicate p);
316+
317+
template<class ForwardIt1, class ForwardIt2>
318+
HIPSYCL_STDPAR_ENTRYPOINT
319+
std::pair<ForwardIt1, ForwardIt2> mismatch( hipsycl::stdpar::par,
320+
ForwardIt1 first1, ForwardIt1 last1,
321+
ForwardIt2 first2, ForwardIt2 last2 );
322+
323+
template<class ForwardIt1, class ForwardIt2, class BinaryPredicate>
324+
HIPSYCL_STDPAR_ENTRYPOINT
325+
std::pair<ForwardIt1, ForwardIt2> mismatch(hipsycl::stdpar::par,
326+
ForwardIt1 first1, ForwardIt1 last1,
327+
ForwardIt2 first2, ForwardIt2 last2,
328+
BinaryPredicate p);
329+
280330
template <class ForwardIt>
281331
HIPSYCL_STDPAR_ENTRYPOINT
282332
ForwardIt min_element(hipsycl::stdpar::par, ForwardIt first, ForwardIt last);

include/hipSYCL/std/stdpar/detail/offload_heuristic_db.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ struct any_of {};
6161
struct none_of {};
6262
struct count{};
6363
struct count_if{};
64+
struct mismatch{};
6465
struct equal {};
6566
struct sort {};
6667
struct is_sorted {};

0 commit comments

Comments
 (0)