Skip to content

Commit b0303bb

Browse files
committed
Revert unneeded cbrt<bfloat16> support
1 parent ac2b4e9 commit b0303bb

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1795,7 +1795,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
17951795
17961796
// cbrt function wrapper.
17971797
template <typename ValueT>
1798-
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
1798+
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1799+
std::is_same_v<ValueT, sycl::half>,
1800+
ValueT>
17991801
cbrt(ValueT val);
18001802
18011803
// For floating-point types, `float` or `double` arguments are acceptable.

sycl/include/syclcompat/math.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -721,13 +721,11 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a) {
721721

722722
/// cbrt function wrapper.
723723
template <typename ValueT>
724-
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
724+
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
725+
std::is_same_v<ValueT, sycl::half>,
726+
ValueT>
725727
cbrt(ValueT val) {
726-
if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) {
727-
return static_cast<ValueT>(sycl::cbrt(static_cast<float>(val)));
728-
} else {
729-
return sycl::cbrt(static_cast<ValueT>(val));
730-
}
728+
return sycl::cbrt(static_cast<ValueT>(val));
731729
}
732730

733731
// min/max function overloads.

sycl/test-e2e/syclcompat/common.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,14 @@ using value_type_list =
6363
#endif
6464
>;
6565

66-
using fp_type_list =
67-
std::tuple<float, double, sycl::half, sycl::ext::oneapi::bfloat16>;
66+
using fp_type_list_no_bfloat16 = std::tuple<float, double, sycl::half>;
67+
68+
using fp_type_list = std::tuple<float, double, sycl::half
69+
70+
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
71+
,sycl::ext::oneapi::bfloat16
72+
#endif
73+
>;
6874

6975
using marray_type_list =
7076
std::tuple<char, signed char, short, int, long, long long, unsigned char,

sycl/test-e2e/syclcompat/math/math_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ int main() {
375375
test_syclcompat_pow<double, int>();
376376

377377
INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_relu);
378-
INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_cbrt);
378+
INSTANTIATE_ALL_TYPES(fp_type_list_no_bfloat16, test_syclcompat_cbrt);
379379

380380
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_isnan);
381381
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_isnan);

0 commit comments

Comments
 (0)