@@ -732,7 +732,7 @@ cbrt(ValueT val) {
732732// For floating-point types, `float` or `double` arguments are acceptable.
733733// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
734734// `std::int64_t` type arguments are acceptable.
735- // sycl::half supported as well.
735+ // sycl::half supported as well, and sycl::ext::oneapi::bfloat16 if available .
736736template <typename ValueT, typename ValueU>
737737inline std::enable_if_t <std::is_integral_v<ValueT> &&
738738 std::is_integral_v<ValueU>,
@@ -741,15 +741,23 @@ min(ValueT a, ValueU b) {
741741 return sycl::min (static_cast <std::common_type_t <ValueT, ValueU>>(a),
742742 static_cast <std::common_type_t <ValueT, ValueU>>(b));
743743}
744+
744745template <typename ValueT, typename ValueU>
745- inline std::enable_if_t <std ::is_floating_point_v<ValueT> &&
746- std ::is_floating_point_v<ValueU>,
746+ inline std::enable_if_t <syclcompat ::is_floating_point_v<ValueT> &&
747+ syclcompat ::is_floating_point_v<ValueU>,
747748 std::common_type_t <ValueT, ValueU>>
748749min (ValueT a, ValueU b) {
749- return sycl::fmin (static_cast <std::common_type_t <ValueT, ValueU>>(a),
750- static_cast <std::common_type_t <ValueT, ValueU>>(b));
750+ if constexpr (detail::support_bfloat16_math &&
751+ std::is_same_v<std::common_type_t <ValueT, ValueU>,
752+ sycl::ext::oneapi::bfloat16>) {
753+ return sycl::ext::oneapi::experimental::fmin (
754+ static_cast <std::common_type_t <ValueT, ValueU>>(a),
755+ static_cast <std::common_type_t <ValueT, ValueU>>(b));
756+ } else {
757+ return sycl::fmin (static_cast <std::common_type_t <ValueT, ValueU>>(a),
758+ static_cast <std::common_type_t <ValueT, ValueU>>(b));
759+ }
751760}
752- inline sycl::half min (sycl::half a, sycl::half b) { return sycl::fmin (a, b); }
753761
754762template <typename ValueT, typename ValueU>
755763inline std::enable_if_t <std::is_integral_v<ValueT> &&
@@ -760,14 +768,21 @@ max(ValueT a, ValueU b) {
760768 static_cast <std::common_type_t <ValueT, ValueU>>(b));
761769}
762770template <typename ValueT, typename ValueU>
763- inline std::enable_if_t <std ::is_floating_point_v<ValueT> &&
764- std ::is_floating_point_v<ValueU>,
771+ inline std::enable_if_t <syclcompat ::is_floating_point_v<ValueT> &&
772+ syclcompat ::is_floating_point_v<ValueU>,
765773 std::common_type_t <ValueT, ValueU>>
766774max (ValueT a, ValueU b) {
767- return sycl::fmax (static_cast <std::common_type_t <ValueT, ValueU>>(a),
768- static_cast <std::common_type_t <ValueT, ValueU>>(b));
775+ if constexpr (detail::support_bfloat16_math &&
776+ std::is_same_v<std::common_type_t <ValueT, ValueU>,
777+ sycl::ext::oneapi::bfloat16>) {
778+ return sycl::ext::oneapi::experimental::fmax (
779+ static_cast <std::common_type_t <ValueT, ValueU>>(a),
780+ static_cast <std::common_type_t <ValueT, ValueU>>(b));
781+ } else {
782+ return sycl::fmax (static_cast <std::common_type_t <ValueT, ValueU>>(a),
783+ static_cast <std::common_type_t <ValueT, ValueU>>(b));
784+ }
769785}
770- inline sycl::half max (sycl::half a, sycl::half b) { return sycl::fmax (a, b); }
771786
772787// / Performs 2 elements comparison and returns the bigger one. If either of
773788// / inputs is NaN, then return NaN.
@@ -779,16 +794,7 @@ inline std::common_type_t<ValueT, ValueU> fmax_nan(const ValueT a,
779794 const ValueU b) {
780795 if (detail::isnan (a) || detail::isnan (b))
781796 return NAN;
782- if constexpr (detail::support_bfloat16_math &&
783- std::is_same_v<std::common_type_t <ValueT, ValueU>,
784- sycl::ext::oneapi::bfloat16>) {
785- return sycl::ext::oneapi::experimental::fmax (
786- static_cast <std::common_type_t <ValueT, ValueU>>(a),
787- static_cast <std::common_type_t <ValueT, ValueU>>(b));
788- } else {
789- return sycl::fmax (static_cast <std::common_type_t <ValueT, ValueU>>(a),
790- static_cast <std::common_type_t <ValueT, ValueU>>(b));
791- }
797+ return syclcompat::max (a, b);
792798}
793799
794800template <typename ValueT, typename ValueU>
@@ -813,16 +819,7 @@ inline std::common_type_t<ValueT, ValueU> fmin_nan(const ValueT a,
813819 const ValueU b) {
814820 if (detail::isnan (a) || detail::isnan (b))
815821 return NAN;
816- if constexpr (detail::support_bfloat16_math &&
817- std::is_same_v<std::common_type_t <ValueT, ValueU>,
818- sycl::ext::oneapi::bfloat16>) {
819- return sycl::ext::oneapi::experimental::fmin (
820- static_cast <std::common_type_t <ValueT, ValueU>>(a),
821- static_cast <std::common_type_t <ValueT, ValueU>>(b));
822- } else {
823- return sycl::fmin (static_cast <std::common_type_t <ValueT, ValueU>>(a),
824- static_cast <std::common_type_t <ValueT, ValueU>>(b));
825- }
822+ return syclcompat::min (a,b);
826823}
827824
828825template <typename ValueT, typename ValueU>
0 commit comments