Skip to content

Commit f21f62e

Browse files
committed
Assert bfloat16 math support in isnan, max, and min
1 parent fc11764 commit f21f62e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ inline constexpr RetT extend_vbinary4(AT a, BT b, RetT c,
254254
}
255255

256256
template <typename ValueT> inline bool isnan(const ValueT a) {
257-
if constexpr (detail::support_bfloat16_math &&
258-
std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) {
257+
if constexpr (std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) {
258+
static_assert(detail::support_bfloat16_math);
259259
return sycl::ext::oneapi::experimental::isnan(a);
260260
} else {
261261
return sycl::isnan(a);
@@ -747,9 +747,9 @@ inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT> &&
747747
syclcompat::is_floating_point_v<ValueU>,
748748
std::common_type_t<ValueT, ValueU>>
749749
min(ValueT a, ValueU b) {
750-
if constexpr (detail::support_bfloat16_math &&
751-
std::is_same_v<std::common_type_t<ValueT, ValueU>,
750+
if constexpr (std::is_same_v<std::common_type_t<ValueT, ValueU>,
752751
sycl::ext::oneapi::bfloat16>) {
752+
static_assert(detail::support_bfloat16_math);
753753
return sycl::ext::oneapi::experimental::fmin(
754754
static_cast<std::common_type_t<ValueT, ValueU>>(a),
755755
static_cast<std::common_type_t<ValueT, ValueU>>(b));
@@ -772,9 +772,9 @@ inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT> &&
772772
syclcompat::is_floating_point_v<ValueU>,
773773
std::common_type_t<ValueT, ValueU>>
774774
max(ValueT a, ValueU b) {
775-
if constexpr (detail::support_bfloat16_math &&
776-
std::is_same_v<std::common_type_t<ValueT, ValueU>,
775+
if constexpr (std::is_same_v<std::common_type_t<ValueT, ValueU>,
777776
sycl::ext::oneapi::bfloat16>) {
777+
static_assert(detail::support_bfloat16_math);
778778
return sycl::ext::oneapi::experimental::fmax(
779779
static_cast<std::common_type_t<ValueT, ValueU>>(a),
780780
static_cast<std::common_type_t<ValueT, ValueU>>(b));

0 commit comments

Comments
 (0)