Skip to content

Commit 118830c

Browse files
committed
[SYCL] fix asserts after logical operation changes
1 parent 65cf13e commit 118830c

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

sycl/include/sycl/group_algorithm.hpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
215215
detail::is_native_op<T, BinaryOperation>::value),
216216
T>
217217
reduce_over_group(Group g, T x, BinaryOperation binary_op) {
218+
219+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
220+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
221+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
222+
? std::is_same_v<decltype(binary_op(x, x)), bool>
223+
: std::is_same_v<decltype(binary_op(x, x)), T>,
224+
"Result type of binary_op must match scan accumulation type.");
225+
#else
218226
static_assert(
219227
std::is_same_v<decltype(binary_op(x, x)), T>,
220228
"Result type of binary_op must match reduction accumulation type.");
229+
#endif
230+
221231
#ifdef __SYCL_DEVICE_ONLY__
222232
#if defined(__NVPTX__)
223233
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -291,9 +301,18 @@ std::enable_if_t<
291301
std::is_convertible_v<V, T>),
292302
T>
293303
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
304+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
305+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
306+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
307+
? std::is_same_v<decltype(binary_op(init, x)), bool>
308+
: std::is_same_v<decltype(binary_op(init, x)), T>,
309+
"Result type of binary_op must match scan accumulation type.");
310+
#else
294311
static_assert(
295312
std::is_same_v<decltype(binary_op(init, x)), T>,
296313
"Result type of binary_op must match reduction accumulation type.");
314+
#endif
315+
297316
#ifdef __SYCL_DEVICE_ONLY__
298317
return binary_op(init, reduce_over_group(g, T(x), binary_op));
299318
#else
@@ -341,9 +360,18 @@ std::enable_if_t<
341360
detail::is_native_op<T, BinaryOperation>::value),
342361
T>
343362
joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
363+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
364+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
365+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
366+
? std::is_same_v<decltype(binary_op(init, *first)), bool>
367+
: std::is_same_v<decltype(binary_op(init, *first)), T>,
368+
"Result type of binary_op must match scan accumulation type.");
369+
#else
344370
static_assert(
345371
std::is_same_v<decltype(binary_op(init, *first)), T>,
346372
"Result type of binary_op must match reduction accumulation type.");
373+
#endif
374+
347375
#ifdef __SYCL_DEVICE_ONLY__
348376
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
349377
sycl::detail::for_each(
@@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
679707
detail::is_native_op<T, BinaryOperation>::value),
680708
T>
681709
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
710+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
711+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
712+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
713+
? std::is_same_v<decltype(binary_op(x, x)), bool>
714+
: std::is_same_v<decltype(binary_op(x, x)), T>,
715+
"Result type of binary_op must match scan accumulation type.");
716+
#else
682717
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
683718
"Result type of binary_op must match scan accumulation type.");
719+
#endif
684720
#ifdef __SYCL_DEVICE_ONLY__
685721
#if defined(__NVPTX__)
686722
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
752788
detail::is_native_op<T, BinaryOperation>::value),
753789
T>
754790
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
791+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
792+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
793+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
794+
? std::is_same_v<decltype(binary_op(x, x)), bool>
795+
: std::is_same_v<decltype(binary_op(x, x)), T>,
796+
"Result type of binary_op must match scan accumulation type.");
797+
#else
755798
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
756799
"Result type of binary_op must match scan accumulation type.");
800+
#endif
757801
T result;
758802
typename detail::get_scalar_binary_op<BinaryOperation>::type
759803
scalar_binary_op{};
@@ -775,8 +819,17 @@ std::enable_if_t<
775819
std::is_convertible_v<V, T>),
776820
T>
777821
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
822+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
823+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
824+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
825+
? std::is_same_v<decltype(binary_op(init, x)), bool>
826+
: std::is_same_v<decltype(binary_op(init, x)), T>,
827+
"Result type of binary_op must match scan accumulation type.");
828+
#else
778829
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
779830
"Result type of binary_op must match scan accumulation type.");
831+
#endif
832+
780833
#ifdef __SYCL_DEVICE_ONLY__
781834
typename Group::linear_id_type local_linear_id =
782835
sycl::detail::get_local_linear_id(g);
@@ -831,8 +884,17 @@ std::enable_if_t<
831884
OutPtr>
832885
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
833886
BinaryOperation binary_op) {
887+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
888+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
889+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
890+
? std::is_same_v<decltype(binary_op(init, *first)), bool>
891+
: std::is_same_v<decltype(binary_op(init, *first)), T>,
892+
"Result type of binary_op must match scan accumulation type.");
893+
#else
834894
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
835895
"Result type of binary_op must match scan accumulation type.");
896+
#endif
897+
836898
#ifdef __SYCL_DEVICE_ONLY__
837899
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
838900
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
@@ -883,9 +945,33 @@ std::enable_if_t<
883945
OutPtr>
884946
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
885947
BinaryOperation binary_op) {
948+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
949+
static_assert(
950+
(std::is_same_v<BinaryOperation,
951+
sycl::logical_or<std::remove_cv_t<
952+
std::remove_reference_t<decltype(*first)>>>> ||
953+
std::is_same_v<BinaryOperation,
954+
sycl::logical_and<std::remove_cv_t<
955+
std::remove_reference_t<decltype(*first)>>>>)
956+
? std::is_same_v<decltype(binary_op(
957+
std::remove_cv_t<
958+
std::remove_reference_t<decltype(*first)>>(),
959+
std::remove_cv_t<std::remove_reference_t<
960+
decltype(*first)>>())),
961+
bool>
962+
: std::is_same_v<
963+
decltype(binary_op(
964+
std::remove_cv_t<
965+
std::remove_reference_t<decltype(*first)>>(),
966+
std::remove_cv_t<
967+
std::remove_reference_t<decltype(*first)>>())),
968+
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>>,
969+
"Result type of binary_op must match scan accumulation type.");
970+
#else
886971
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
887972
typename detail::remove_pointer<OutPtr>::type>,
888973
"Result type of binary_op must match scan accumulation type.");
974+
#endif
889975
using T = typename detail::remove_pointer<OutPtr>::type;
890976
T init = detail::identity_for_ga_op<T, BinaryOperation>();
891977
return joint_exclusive_scan(g, first, last, result, init, binary_op);
@@ -903,8 +989,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
903989
detail::is_native_op<T, BinaryOperation>::value),
904990
T>
905991
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
992+
993+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
994+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
995+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
996+
? std::is_same_v<decltype(binary_op(x, x)), bool>
997+
: std::is_same_v<decltype(binary_op(x, x)), T>,
998+
"Result type of binary_op must match scan accumulation type.");
999+
#else
1000+
9061001
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
9071002
"Result type of binary_op must match scan accumulation type.");
1003+
#endif
1004+
9081005
#ifdef __SYCL_DEVICE_ONLY__
9091006
#if defined(__NVPTX__)
9101007
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -972,8 +1069,18 @@ std::enable_if_t<
9721069
std::is_convertible_v<V, T>),
9731070
T>
9741071
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
1072+
1073+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1074+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
1075+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
1076+
? std::is_same_v<decltype(binary_op(init, x)), bool>
1077+
: std::is_same_v<decltype(binary_op(init, x)), T>,
1078+
"Result type of binary_op must match scan accumulation type.");
1079+
#else
9751080
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
9761081
"Result type of binary_op must match scan accumulation type.");
1082+
#endif
1083+
9771084
#ifdef __SYCL_DEVICE_ONLY__
9781085
T y = x;
9791086
if (sycl::detail::get_local_linear_id(g) == 0) {
@@ -1022,8 +1129,17 @@ std::enable_if_t<
10221129
OutPtr>
10231130
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
10241131
BinaryOperation binary_op, T init) {
1132+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1133+
static_assert((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
1134+
std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
1135+
? std::is_same_v<decltype(binary_op(init, *first)), bool>
1136+
: std::is_same_v<decltype(binary_op(init, *first)), T>,
1137+
"Result type of binary_op must match scan accumulation type.");
1138+
#else
10251139
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
10261140
"Result type of binary_op must match scan accumulation type.");
1141+
#endif
1142+
10271143
#ifdef __SYCL_DEVICE_ONLY__
10281144
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
10291145
ptrdiff_t stride = sycl::detail::get_local_linear_range(g);
@@ -1071,9 +1187,33 @@ std::enable_if_t<
10711187
OutPtr>
10721188
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
10731189
BinaryOperation binary_op) {
1190+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1191+
static_assert(
1192+
(std::is_same_v<BinaryOperation,
1193+
sycl::logical_or<std::remove_cv_t<
1194+
std::remove_reference_t<decltype(*first)>>>> ||
1195+
std::is_same_v<BinaryOperation,
1196+
sycl::logical_and<std::remove_cv_t<
1197+
std::remove_reference_t<decltype(*first)>>>>)
1198+
? std::is_same_v<decltype(binary_op(
1199+
std::remove_cv_t<
1200+
std::remove_reference_t<decltype(*first)>>(),
1201+
std::remove_cv_t<std::remove_reference_t<
1202+
decltype(*first)>>())),
1203+
bool>
1204+
: std::is_same_v<
1205+
decltype(binary_op(
1206+
std::remove_cv_t<
1207+
std::remove_reference_t<decltype(*first)>>(),
1208+
std::remove_cv_t<
1209+
std::remove_reference_t<decltype(*first)>>())),
1210+
std::remove_cv_t<std::remove_reference_t<decltype(*first)>>>,
1211+
"Result type of binary_op must match scan accumulation type.");
1212+
#else
10741213
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
10751214
typename detail::remove_pointer<OutPtr>::type>,
10761215
"Result type of binary_op must match scan accumulation type.");
1216+
#endif
10771217

10781218
using T = typename detail::remove_pointer<OutPtr>::type;
10791219
T init = detail::identity_for_ga_op<T, BinaryOperation>();

0 commit comments

Comments
 (0)