@@ -215,9 +215,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
215215 detail::is_native_op<T, BinaryOperation>::value),
216216 T>
217217reduce_over_group (Group g, T x, BinaryOperation binary_op) {
218+
219+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
220+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
221+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
222+ ? std::is_same_v<decltype (binary_op (x, x)), bool >
223+ : std::is_same_v<decltype (binary_op (x, x)), T>,
224+ " Result type of binary_op must match scan accumulation type." );
225+ #else
218226 static_assert (
219227 std::is_same_v<decltype (binary_op (x, x)), T>,
220228 " Result type of binary_op must match reduction accumulation type." );
229+ #endif
230+
221231#ifdef __SYCL_DEVICE_ONLY__
222232#if defined(__NVPTX__)
223233 if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -291,9 +301,18 @@ std::enable_if_t<
291301 std::is_convertible_v<V, T>),
292302 T>
293303reduce_over_group (Group g, V x, T init, BinaryOperation binary_op) {
304+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
305+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
306+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
307+ ? std::is_same_v<decltype (binary_op (init, x)), bool >
308+ : std::is_same_v<decltype (binary_op (init, x)), T>,
309+ " Result type of binary_op must match scan accumulation type." );
310+ #else
294311 static_assert (
295312 std::is_same_v<decltype (binary_op (init, x)), T>,
296313 " Result type of binary_op must match reduction accumulation type." );
314+ #endif
315+
297316#ifdef __SYCL_DEVICE_ONLY__
298317 return binary_op (init, reduce_over_group (g, T (x), binary_op));
299318#else
@@ -341,9 +360,18 @@ std::enable_if_t<
341360 detail::is_native_op<T, BinaryOperation>::value),
342361 T>
343362joint_reduce (Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
363+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
364+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
365+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
366+ ? std::is_same_v<decltype (binary_op (init, *first)), bool >
367+ : std::is_same_v<decltype (binary_op (init, *first)), T>,
368+ " Result type of binary_op must match scan accumulation type." );
369+ #else
344370 static_assert (
345371 std::is_same_v<decltype (binary_op (init, *first)), T>,
346372 " Result type of binary_op must match reduction accumulation type." );
373+ #endif
374+
347375#ifdef __SYCL_DEVICE_ONLY__
348376 T partial = detail::identity_for_ga_op<T, BinaryOperation>();
349377 sycl::detail::for_each (
@@ -679,8 +707,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
679707 detail::is_native_op<T, BinaryOperation>::value),
680708 T>
681709exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
710+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
711+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
712+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
713+ ? std::is_same_v<decltype (binary_op (x, x)), bool >
714+ : std::is_same_v<decltype (binary_op (x, x)), T>,
715+ " Result type of binary_op must match scan accumulation type." );
716+ #else
682717 static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
683718 " Result type of binary_op must match scan accumulation type." );
719+ #endif
684720#ifdef __SYCL_DEVICE_ONLY__
685721#if defined(__NVPTX__)
686722 if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -752,8 +788,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
752788 detail::is_native_op<T, BinaryOperation>::value),
753789 T>
754790exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
791+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
792+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
793+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
794+ ? std::is_same_v<decltype (binary_op (x, x)), bool >
795+ : std::is_same_v<decltype (binary_op (x, x)), T>,
796+ " Result type of binary_op must match scan accumulation type." );
797+ #else
755798 static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
756799 " Result type of binary_op must match scan accumulation type." );
800+ #endif
757801 T result;
758802 typename detail::get_scalar_binary_op<BinaryOperation>::type
759803 scalar_binary_op{};
@@ -775,8 +819,17 @@ std::enable_if_t<
775819 std::is_convertible_v<V, T>),
776820 T>
777821exclusive_scan_over_group (Group g, V x, T init, BinaryOperation binary_op) {
822+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
823+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
824+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
825+ ? std::is_same_v<decltype (binary_op (init, x)), bool >
826+ : std::is_same_v<decltype (binary_op (init, x)), T>,
827+ " Result type of binary_op must match scan accumulation type." );
828+ #else
778829 static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
779830 " Result type of binary_op must match scan accumulation type." );
831+ #endif
832+
780833#ifdef __SYCL_DEVICE_ONLY__
781834 typename Group::linear_id_type local_linear_id =
782835 sycl::detail::get_local_linear_id (g);
@@ -831,8 +884,17 @@ std::enable_if_t<
831884 OutPtr>
832885joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result, T init,
833886 BinaryOperation binary_op) {
887+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
888+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
889+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
890+ ? std::is_same_v<decltype (binary_op (init, *first)), bool >
891+ : std::is_same_v<decltype (binary_op (init, *first)), T>,
892+ " Result type of binary_op must match scan accumulation type." );
893+ #else
834894 static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
835895 " Result type of binary_op must match scan accumulation type." );
896+ #endif
897+
836898#ifdef __SYCL_DEVICE_ONLY__
837899 ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
838900 ptrdiff_t stride = sycl::detail::get_local_linear_range (g);
@@ -883,9 +945,33 @@ std::enable_if_t<
883945 OutPtr>
884946joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
885947 BinaryOperation binary_op) {
948+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
949+ static_assert (
950+ (std::is_same_v<BinaryOperation,
951+ sycl::logical_or<std::remove_cv_t <
952+ std::remove_reference_t <decltype (*first)>>>> ||
953+ std::is_same_v<BinaryOperation,
954+ sycl::logical_and<std::remove_cv_t <
955+ std::remove_reference_t <decltype (*first)>>>>)
956+ ? std::is_same_v<decltype (binary_op (
957+ std::remove_cv_t <
958+ std::remove_reference_t <decltype (*first)>>(),
959+ std::remove_cv_t <std::remove_reference_t <
960+ decltype (*first)>>())),
961+ bool >
962+ : std::is_same_v<
963+ decltype (binary_op (
964+ std::remove_cv_t <
965+ std::remove_reference_t <decltype (*first)>>(),
966+ std::remove_cv_t <
967+ std::remove_reference_t <decltype (*first)>>())),
968+ std::remove_cv_t <std::remove_reference_t <decltype (*first)>>>,
969+ " Result type of binary_op must match scan accumulation type." );
970+ #else
886971 static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
887972 typename detail::remove_pointer<OutPtr>::type>,
888973 " Result type of binary_op must match scan accumulation type." );
974+ #endif
889975 using T = typename detail::remove_pointer<OutPtr>::type;
890976 T init = detail::identity_for_ga_op<T, BinaryOperation>();
891977 return joint_exclusive_scan (g, first, last, result, init, binary_op);
@@ -903,8 +989,19 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
903989 detail::is_native_op<T, BinaryOperation>::value),
904990 T>
905991inclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
992+
993+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
994+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
995+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
996+ ? std::is_same_v<decltype (binary_op (x, x)), bool >
997+ : std::is_same_v<decltype (binary_op (x, x)), T>,
998+ " Result type of binary_op must match scan accumulation type." );
999+ #else
1000+
9061001 static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
9071002 " Result type of binary_op must match scan accumulation type." );
1003+ #endif
1004+
9081005#ifdef __SYCL_DEVICE_ONLY__
9091006#if defined(__NVPTX__)
9101007 if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<Group>) {
@@ -972,8 +1069,18 @@ std::enable_if_t<
9721069 std::is_convertible_v<V, T>),
9731070 T>
9741071inclusive_scan_over_group (Group g, V x, BinaryOperation binary_op, T init) {
1072+
1073+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1074+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
1075+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
1076+ ? std::is_same_v<decltype (binary_op (init, x)), bool >
1077+ : std::is_same_v<decltype (binary_op (init, x)), T>,
1078+ " Result type of binary_op must match scan accumulation type." );
1079+ #else
9751080 static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
9761081 " Result type of binary_op must match scan accumulation type." );
1082+ #endif
1083+
9771084#ifdef __SYCL_DEVICE_ONLY__
9781085 T y = x;
9791086 if (sycl::detail::get_local_linear_id (g) == 0 ) {
@@ -1022,8 +1129,17 @@ std::enable_if_t<
10221129 OutPtr>
10231130joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
10241131 BinaryOperation binary_op, T init) {
1132+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1133+ static_assert ((std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
1134+ std::is_same_v<BinaryOperation, sycl::logical_and<T>>)
1135+ ? std::is_same_v<decltype (binary_op (init, *first)), bool >
1136+ : std::is_same_v<decltype (binary_op (init, *first)), T>,
1137+ " Result type of binary_op must match scan accumulation type." );
1138+ #else
10251139 static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
10261140 " Result type of binary_op must match scan accumulation type." );
1141+ #endif
1142+
10271143#ifdef __SYCL_DEVICE_ONLY__
10281144 ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
10291145 ptrdiff_t stride = sycl::detail::get_local_linear_range (g);
@@ -1071,9 +1187,33 @@ std::enable_if_t<
10711187 OutPtr>
10721188joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
10731189 BinaryOperation binary_op) {
1190+ #ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1191+ static_assert (
1192+ (std::is_same_v<BinaryOperation,
1193+ sycl::logical_or<std::remove_cv_t <
1194+ std::remove_reference_t <decltype (*first)>>>> ||
1195+ std::is_same_v<BinaryOperation,
1196+ sycl::logical_and<std::remove_cv_t <
1197+ std::remove_reference_t <decltype (*first)>>>>)
1198+ ? std::is_same_v<decltype (binary_op (
1199+ std::remove_cv_t <
1200+ std::remove_reference_t <decltype (*first)>>(),
1201+ std::remove_cv_t <std::remove_reference_t <
1202+ decltype (*first)>>())),
1203+ bool >
1204+ : std::is_same_v<
1205+ decltype (binary_op (
1206+ std::remove_cv_t <
1207+ std::remove_reference_t <decltype (*first)>>(),
1208+ std::remove_cv_t <
1209+ std::remove_reference_t <decltype (*first)>>())),
1210+ std::remove_cv_t <std::remove_reference_t <decltype (*first)>>>,
1211+ " Result type of binary_op must match scan accumulation type." );
1212+ #else
10741213 static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
10751214 typename detail::remove_pointer<OutPtr>::type>,
10761215 " Result type of binary_op must match scan accumulation type." );
1216+ #endif
10771217
10781218 using T = typename detail::remove_pointer<OutPtr>::type;
10791219 T init = detail::identity_for_ga_op<T, BinaryOperation>();
0 commit comments