Skip to content

Commit d16dd56

Browse files
committed
Add bfloat16 support to `syclcompat::max` and `syclcompat::min`, plus tests.
Also reimplement `fmax_nan` and `fmin_nan` in terms of the improved `max`/`min`.
1 parent 51860b3 commit d16dd56

File tree

2 files changed: +37 -33 lines changed

2 files changed: +37 -33 lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ cbrt(ValueT val) {
732732
// For floating-point types, `float` or `double` arguments are acceptable.
733733
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
734734
// `std::int64_t` type arguments are acceptable.
735-
// sycl::half supported as well.
735+
// sycl::half supported as well, and sycl::ext::oneapi::bfloat16 if available.
736736
template <typename ValueT, typename ValueU>
737737
inline std::enable_if_t<std::is_integral_v<ValueT> &&
738738
std::is_integral_v<ValueU>,
@@ -741,15 +741,23 @@ min(ValueT a, ValueU b) {
741741
return sycl::min(static_cast<std::common_type_t<ValueT, ValueU>>(a),
742742
static_cast<std::common_type_t<ValueT, ValueU>>(b));
743743
}
744+
744745
template <typename ValueT, typename ValueU>
745-
inline std::enable_if_t<std::is_floating_point_v<ValueT> &&
746-
std::is_floating_point_v<ValueU>,
746+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT> &&
747+
syclcompat::is_floating_point_v<ValueU>,
747748
std::common_type_t<ValueT, ValueU>>
748749
min(ValueT a, ValueU b) {
749-
return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
750-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
750+
if constexpr (detail::support_bfloat16_math &&
751+
std::is_same_v<std::common_type_t<ValueT, ValueU>,
752+
sycl::ext::oneapi::bfloat16>) {
753+
return sycl::ext::oneapi::experimental::fmin(
754+
static_cast<std::common_type_t<ValueT, ValueU>>(a),
755+
static_cast<std::common_type_t<ValueT, ValueU>>(b));
756+
} else {
757+
return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
758+
static_cast<std::common_type_t<ValueT, ValueU>>(b));
759+
}
751760
}
752-
inline sycl::half min(sycl::half a, sycl::half b) { return sycl::fmin(a, b); }
753761

754762
template <typename ValueT, typename ValueU>
755763
inline std::enable_if_t<std::is_integral_v<ValueT> &&
@@ -760,14 +768,21 @@ max(ValueT a, ValueU b) {
760768
static_cast<std::common_type_t<ValueT, ValueU>>(b));
761769
}
762770
template <typename ValueT, typename ValueU>
763-
inline std::enable_if_t<std::is_floating_point_v<ValueT> &&
764-
std::is_floating_point_v<ValueU>,
771+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT> &&
772+
syclcompat::is_floating_point_v<ValueU>,
765773
std::common_type_t<ValueT, ValueU>>
766774
max(ValueT a, ValueU b) {
767-
return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
768-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
775+
if constexpr (detail::support_bfloat16_math &&
776+
std::is_same_v<std::common_type_t<ValueT, ValueU>,
777+
sycl::ext::oneapi::bfloat16>) {
778+
return sycl::ext::oneapi::experimental::fmax(
779+
static_cast<std::common_type_t<ValueT, ValueU>>(a),
780+
static_cast<std::common_type_t<ValueT, ValueU>>(b));
781+
} else {
782+
return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
783+
static_cast<std::common_type_t<ValueT, ValueU>>(b));
784+
}
769785
}
770-
inline sycl::half max(sycl::half a, sycl::half b) { return sycl::fmax(a, b); }
771786

772787
/// Performs 2 elements comparison and returns the bigger one. If either of
773788
/// inputs is NaN, then return NaN.
@@ -779,16 +794,7 @@ inline std::common_type_t<ValueT, ValueU> fmax_nan(const ValueT a,
779794
const ValueU b) {
780795
if (detail::isnan(a) || detail::isnan(b))
781796
return NAN;
782-
if constexpr (detail::support_bfloat16_math &&
783-
std::is_same_v<std::common_type_t<ValueT, ValueU>,
784-
sycl::ext::oneapi::bfloat16>) {
785-
return sycl::ext::oneapi::experimental::fmax(
786-
static_cast<std::common_type_t<ValueT, ValueU>>(a),
787-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
788-
} else {
789-
return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
790-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
791-
}
797+
return syclcompat::max(a, b);
792798
}
793799

794800
template <typename ValueT, typename ValueU>
@@ -813,16 +819,7 @@ inline std::common_type_t<ValueT, ValueU> fmin_nan(const ValueT a,
813819
const ValueU b) {
814820
if (detail::isnan(a) || detail::isnan(b))
815821
return NAN;
816-
if constexpr (detail::support_bfloat16_math &&
817-
std::is_same_v<std::common_type_t<ValueT, ValueU>,
818-
sycl::ext::oneapi::bfloat16>) {
819-
return sycl::ext::oneapi::experimental::fmin(
820-
static_cast<std::common_type_t<ValueT, ValueU>>(a),
821-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
822-
} else {
823-
return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
824-
static_cast<std::common_type_t<ValueT, ValueU>>(b));
825-
}
822+
return syclcompat::min(a,b);
826823
}
827824

828825
template <typename ValueT, typename ValueU>

sycl/test-e2e/syclcompat/math/math_ops.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
template <typename ValueT, typename ValueU>
3535
inline void max_kernel(ValueT *a, ValueU *b,
3636
std::common_type_t<ValueT, ValueU> *r) {
37-
*r = syclcompat::max(*a, *b);
37+
*r = syclcompat::max<ValueT, ValueU>(*a, *b);
3838
}
3939

4040
template <typename ValueT, typename ValueU = ValueT>
@@ -54,7 +54,7 @@ void test_syclcompat_max() {
5454
template <typename ValueT, typename ValueU>
5555
inline void min_kernel(ValueT *a, ValueU *b,
5656
std::common_type_t<ValueT, ValueU> *r) {
57-
*r = syclcompat::min(*a, *b);
57+
*r = syclcompat::min<ValueT,ValueU>(*a, *b);
5858
}
5959

6060
template <typename ValueT, typename ValueU = ValueT>
@@ -342,8 +342,15 @@ int main() {
342342
// Basic testing of deduction to avoid combinatorial explosion
343343
test_syclcompat_max<double, float>();
344344
test_syclcompat_max<long, int>();
345+
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
346+
test_syclcompat_max<sycl::ext::oneapi::bfloat16, float>();
347+
#endif
348+
345349
test_syclcompat_min<double, float>();
346350
test_syclcompat_min<long, int>();
351+
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
352+
test_syclcompat_min<sycl::ext::oneapi::bfloat16, float>();
353+
#endif
347354

348355
INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_fmin_nan);
349356
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_container_syclcompat_fmin_nan);

0 commit comments

Comments
 (0)