Skip to content

Commit 5b0e0b5

Browse files
committed
Enable relu for bfloat
1 parent 55e95ad commit 5b0e0b5

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -838,24 +838,20 @@ pow(const ValueT a, const ValueU b) {
838838
/// \param [in] a The input value
839839
/// \returns the relu saturation result
840840
template <typename ValueT>
841-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
842-
std::is_same_v<sycl::half, ValueT>,
843-
ValueT>
841+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
844842
relu(const ValueT a) {
845843
if (!detail::isnan(a) && a < ValueT(0))
846844
return ValueT(0);
847845
return a;
848846
}
849847
template <class ValueT>
850-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
851-
std::is_same_v<sycl::half, ValueT>,
848+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>,
852849
sycl::vec<ValueT, 2>>
853850
relu(const sycl::vec<ValueT, 2> a) {
854851
return {relu(a[0]), relu(a[1])};
855852
}
856853
template <class ValueT>
857-
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
858-
std::is_same_v<sycl::half, ValueT>,
854+
inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>,
859855
sycl::marray<ValueT, 2>>
860856
relu(const sycl::marray<ValueT, 2> a) {
861857
return {relu(a[0]), relu(a[1])};

0 commit comments

Comments
 (0)