Skip to content

Commit 0b43551

Browse files
authored
oneDPL: Mark comparators as device-copyable (kokkos#7538)
* oneDPL: Mark comparators as device-copyable * Narrow down guards * Compatibility with MSVC * Indentation * Fix * Indentation * Fix * Indentation * Indentation * Always define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL * Update version check to 2022.8.0 since no release can detect 2022.7.1 * Remove oneDPL workaround * Remove redundant check * Apply suggestions from code review * Indentation
1 parent e5180b7 commit 0b43551

File tree

4 files changed

+120
-22
lines changed

4 files changed

+120
-22
lines changed

algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
#endif
8888
#endif
8989

90+
#ifndef KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL
91+
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATCH) 0
92+
#endif
93+
9094
namespace Kokkos::Impl {
9195

9296
template <typename T>
@@ -156,7 +160,7 @@ void sort_by_key_rocthrust(
156160

157161
#if defined(KOKKOS_ENABLE_ONEDPL)
158162

159-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
163+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
160164
template <class Layout>
161165
inline constexpr bool sort_on_device_v<Kokkos::SYCL, Layout> = true;
162166
#else
@@ -176,11 +180,11 @@ void sort_by_key_onedpl(
176180
MaybeComparator&&... maybeComparator) {
177181
auto queue = exec.sycl_queue();
178182
auto policy = oneapi::dpl::execution::make_device_policy(queue);
179-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
180-
oneapi::dpl::sort_by_key(policy, ::Kokkos::Experimental::begin(keys),
181-
::Kokkos::Experimental::end(keys),
182-
::Kokkos::Experimental::begin(values),
183-
std::forward<MaybeComparator>(maybeComparator)...);
183+
184+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
185+
auto keys_begin = ::Kokkos::Experimental::begin(keys);
186+
auto keys_end = ::Kokkos::Experimental::end(keys);
187+
auto values_begin = ::Kokkos::Experimental::begin(values);
184188
#else
185189
if (keys.stride(0) != 1 && values.stride(0) != 1) {
186190
Kokkos::abort(
@@ -189,10 +193,24 @@ void sort_by_key_onedpl(
189193

190194
// Can't use Experimental::begin/end here since the oneDPL then assumes that
191195
// the data is on the host.
192-
const int n = keys.extent(0);
193-
oneapi::dpl::sort_by_key(policy, keys.data(), keys.data() + n, values.data(),
194-
std::forward<MaybeComparator>(maybeComparator)...);
196+
const int n = keys.extent(0);
197+
auto keys_begin = keys.data();
198+
auto keys_end = keys.data() + n;
199+
auto values_begin = values.data();
195200
#endif
201+
202+
if constexpr (sizeof...(MaybeComparator) == 0)
203+
oneapi::dpl::sort_by_key(policy, keys_begin, keys_end, values_begin);
204+
else {
205+
using keys_value_type =
206+
typename Kokkos::View<KeysDataType, KeysProperties...>::value_type;
207+
auto keys_comparator =
208+
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
209+
oneapi::dpl::sort_by_key(
210+
policy, keys_begin, keys_end, values_begin,
211+
ComparatorWrapper<decltype(keys_comparator), keys_value_type>{
212+
keys_comparator});
213+
}
196214
}
197215
#endif
198216
#endif
@@ -292,7 +310,8 @@ void sort_by_key_via_sort(
292310
host_exec.fence("Kokkos::Impl::sort_by_key_via_sort: after host sort");
293311
Kokkos::deep_copy(exec, permute, host_permute);
294312
} else {
295-
#ifdef KOKKOS_ENABLE_SYCL
313+
#if defined(KOKKOS_IMPL_ONEDPL_HAS_SORT_BY_KEY) && \
314+
!KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
296315
auto* raw_keys_in_comparator = keys.data();
297316
auto stride = keys.stride(0);
298317
if constexpr (sizeof...(MaybeComparator) == 0) {
@@ -364,7 +383,7 @@ void sort_by_key_device_view_without_comparator(
364383
const Kokkos::View<KeysDataType, KeysProperties...>& keys,
365384
const Kokkos::View<ValuesDataType, ValuesProperties...>& values) {
366385
#ifdef KOKKOS_IMPL_ONEDPL_HAS_SORT_BY_KEY
367-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
386+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
368387
sort_by_key_onedpl(exec, keys, values);
369388
#else
370389
if (keys.stride(0) == 1 && values.stride(0) == 1)
@@ -428,7 +447,7 @@ void sort_by_key_device_view_with_comparator(
428447
const Kokkos::View<ValuesDataType, ValuesProperties...>& values,
429448
const ComparatorType& comparator) {
430449
#ifdef KOKKOS_IMPL_ONEDPL_HAS_SORT_BY_KEY
431-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
450+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
432451
sort_by_key_onedpl(exec, keys, values, comparator);
433452
#else
434453
if (keys.stride(0) == 1 && values.stride(0) == 1)

algorithms/src/sorting/impl/Kokkos_SortImpl.hpp

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@
8686
ONEDPL_VERSION_PATCH
8787
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATCH) \
8888
(KOKKOS_IMPL_ONEDPL_VERSION >= ((MAJOR)*10000 + (MINOR)*100 + (PATCH)))
89+
90+
namespace Kokkos::Impl {
91+
template <typename Comparator, typename ValueType>
92+
struct ComparatorWrapper {
93+
Comparator comparator;
94+
KOKKOS_FUNCTION bool operator()(const ValueType& i,
95+
const ValueType& j) const {
96+
return comparator(i, j);
97+
}
98+
};
99+
} // namespace Kokkos::Impl
100+
101+
template <typename Comparator, typename ValueType>
102+
struct sycl::is_device_copyable<
103+
Kokkos::Impl::ComparatorWrapper<Comparator, ValueType>> : std::true_type {};
89104
#endif
90105

91106
namespace Kokkos {
@@ -235,7 +250,7 @@ void sort_onedpl(const Kokkos::SYCL& space,
235250
"SYCL execution space is not able to access the memory space "
236251
"of the View argument!");
237252

238-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
253+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
239254
static_assert(ViewType::rank == 1,
240255
"Kokkos::sort currently only supports rank-1 Views.");
241256
#else
@@ -261,17 +276,28 @@ void sort_onedpl(const Kokkos::SYCL& space,
261276
auto queue = space.sycl_queue();
262277
auto policy = oneapi::dpl::execution::make_device_policy(queue);
263278

264-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
265-
oneapi::dpl::sort(policy, ::Kokkos::Experimental::begin(view),
266-
::Kokkos::Experimental::end(view),
267-
std::forward<MaybeComparator>(maybeComparator)...);
279+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
280+
auto view_begin = ::Kokkos::Experimental::begin(view);
281+
auto view_end = ::Kokkos::Experimental::end(view);
268282
#else
269283
// Can't use Experimental::begin/end here since the oneDPL then assumes that
270284
// the data is on the host.
271-
const int n = view.extent(0);
272-
oneapi::dpl::sort(policy, view.data(), view.data() + n,
273-
std::forward<MaybeComparator>(maybeComparator)...);
285+
const int n = view.extent(0);
286+
auto view_begin = view.data();
287+
auto view_end = view.data() + n;
274288
#endif
289+
290+
if constexpr (sizeof...(MaybeComparator) == 0)
291+
oneapi::dpl::sort(policy, view_begin, view_end);
292+
else {
293+
using value_type =
294+
typename Kokkos::View<DataType, Properties...>::value_type;
295+
auto comparator =
296+
std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
297+
oneapi::dpl::sort(
298+
policy, view_begin, view_end,
299+
ComparatorWrapper<decltype(comparator), value_type>{comparator});
300+
}
275301
}
276302
#endif
277303

@@ -348,7 +374,7 @@ void sort_device_view_without_comparator(
348374
"sort_device_view_without_comparator: supports rank-1 Views "
349375
"with LayoutLeft, LayoutRight or LayoutStride");
350376

351-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
377+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
352378
sort_onedpl(exec, view);
353379
#else
354380
if (view.stride(0) == 1) {
@@ -407,7 +433,7 @@ void sort_device_view_with_comparator(
407433
"sort_device_view_with_comparator: supports rank-1 Views "
408434
"with LayoutLeft, LayoutRight or LayoutStride");
409435

410-
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 7, 1)
436+
#if KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
411437
sort_onedpl(exec, view, comparator);
412438
#else
413439
if (view.stride(0) == 1) {

algorithms/unit_tests/TestSortByKey.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,29 @@
2424

2525
#include <utility> // pair
2626

27+
#if defined(KOKKOS_ENABLE_ONEDPL)
28+
#define KOKKOS_IMPL_ONEDPL_VERSION \
29+
ONEDPL_VERSION_MAJOR * 10000 + ONEDPL_VERSION_MINOR * 100 + \
30+
ONEDPL_VERSION_PATCH
31+
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATCH) \
32+
(KOKKOS_IMPL_ONEDPL_VERSION >= ((MAJOR)*10000 + (MINOR)*100 + (PATCH)))
33+
#endif
34+
35+
#ifndef KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL
36+
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATH) 0
37+
#endif
38+
2739
namespace Test {
2840
namespace SortImpl {
2941

3042
struct Less {
43+
#if !defined(KOKKOS_ENABLE_ONEDPL) || \
44+
KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
45+
// Test with a comparator that isn't trivially copyable if oneDPL is not
46+
// enabled or if oneDPL version >= 2022.8.0
47+
Kokkos::View<int *> dummy;
48+
#endif
49+
3150
template <class ValueType>
3251
KOKKOS_INLINE_FUNCTION bool operator()(const ValueType &lhs,
3352
const ValueType &rhs) const {
@@ -36,6 +55,13 @@ struct Less {
3655
};
3756

3857
struct Greater {
58+
#if !defined(KOKKOS_ENABLE_ONEDPL) || \
59+
KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
60+
// Test with a comparator that isn't trivially copyable if oneDPL is not
61+
// enabled or if oneDPL version >= 2022.8.0
62+
Kokkos::View<int *> dummy;
63+
#endif
64+
3965
template <class ValueType>
4066
KOKKOS_INLINE_FUNCTION bool operator()(const ValueType &lhs,
4167
const ValueType &rhs) const {
@@ -252,4 +278,8 @@ TEST(TEST_CATEGORY_DEATH, SortByKeyKeysLargerThanValues) {
252278
}
253279

254280
} // namespace Test
281+
282+
#undef KOKKOS_IMPL_ONEDPL_VERSION
283+
#undef KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL
284+
255285
#endif

algorithms/unit_tests/TestSortCustomComp.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323
#include <Kokkos_Sort.hpp>
2424
#include <TestStdAlgorithmsCommon.hpp>
2525

26+
#if defined(KOKKOS_ENABLE_ONEDPL)
27+
#define KOKKOS_IMPL_ONEDPL_VERSION \
28+
ONEDPL_VERSION_MAJOR * 10000 + ONEDPL_VERSION_MINOR * 100 + \
29+
ONEDPL_VERSION_PATCH
30+
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATCH) \
31+
(KOKKOS_IMPL_ONEDPL_VERSION >= ((MAJOR)*10000 + (MINOR)*100 + (PATCH)))
32+
#endif
33+
34+
#ifndef KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL
35+
#define KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(MAJOR, MINOR, PATH) 0
36+
#endif
37+
2638
namespace {
2739
namespace SortWithComp {
2840

@@ -62,6 +74,13 @@ auto create_random_view_and_host_clone(
6274

6375
template <class T>
6476
struct MyComp {
77+
#if !defined(KOKKOS_ENABLE_ONEDPL) || \
78+
KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL(2022, 8, 0)
79+
// Make sure that the comparator isn't device copyable, this caused problems
80+
// with SYCL/oneDPL
81+
Kokkos::View<T*> dummy;
82+
#endif
83+
6584
KOKKOS_FUNCTION
6685
bool operator()(T a, T b) const {
6786
// we return a>b on purpose here, rather than doing a<b
@@ -130,4 +149,8 @@ TEST(TEST_CATEGORY, SortWithCustomComparator) {
130149

131150
} // namespace SortWithComp
132151
} // namespace anonym
152+
153+
#undef KOKKOS_IMPL_ONEDPL_VERSION
154+
#undef KOKKOS_IMPL_ONEDPL_VERSION_GREATER_EQUAL
155+
133156
#endif

0 commit comments

Comments
 (0)