Skip to content

Commit e5f9231

Browse files
committed
Add bfloat16 support to cbrt
1 parent fd9ebf4 commit e5f9231

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,12 +1736,8 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
17361736
17371737
// cbrt function wrapper.
17381738
template <typename ValueT>
1739-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1740-
std::is_same_v<sycl::half, ValueT>,
1741-
ValueT>
1742-
cbrt(ValueT val) {
1743-
return sycl::cbrt(static_cast<ValueT>(val));
1744-
}
1739+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
1740+
cbrt(ValueT val);
17451741
17461742
// For floating-point types, `float` or `double` arguments are acceptable.
17471743
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or

sycl/include/syclcompat/math.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,13 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a) {
721721

722722
/// cbrt function wrapper.
723723
template <typename ValueT>
724-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
725-
std::is_same_v<sycl::half, ValueT>,
726-
ValueT>
724+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
727725
cbrt(ValueT val) {
728-
return sycl::cbrt(static_cast<ValueT>(val));
726+
if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) {
727+
return static_cast<ValueT>(sycl::cbrt(static_cast<float>(val)));
728+
} else {
729+
return sycl::cbrt(static_cast<ValueT>(val));
730+
}
729731
}
730732

731733
// min/max function overloads.

0 commit comments

Comments
 (0)