Skip to content

Commit 1791115

Browse files
authored
[SYCL][COMPAT] Add bfloat16 support to several maths ops (#15572)
Adds support for `sycl::ext::oneapi::bfloat16` to: - `relu` - `clamp` - `fmax_nan` - `fmin_nan` - `min` - `max` - `compare_mask` - `unordered_compare_mask`
1 parent 3796776 commit 1791115

File tree

9 files changed

+498
-164
lines changed

9 files changed

+498
-164
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,7 +1729,51 @@ second operand, respectively. These three APIs return a single 32-bit value with
17291729
the accumulated result, which is unsigned if both operands are `uint32_t` and
17301730
signed otherwise.
17311731

1732+
Various maths functions are defined to operate on any floating-point type.
1733+
`syclcompat::is_floating_point_v` extends the standard library's
1734+
`std::is_floating_point_v` to include `sycl::half` and, where available,
1735+
`sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides
1736+
a specialization of `std::common_type` for `sycl::ext::oneapi::bfloat16`,
1737+
though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in
1738+
future.
1739+
1740+
```cpp
1741+
namespace std {
1742+
template <> struct common_type<sycl::ext::oneapi::bfloat16> {
1743+
using type = sycl::ext::oneapi::bfloat16;
1744+
};
1745+
1746+
template <>
1747+
struct common_type<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16> {
1748+
using type = sycl::ext::oneapi::bfloat16;
1749+
};
1750+
1751+
template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
1752+
using type = sycl::ext::oneapi::bfloat16;
1753+
};
1754+
1755+
template <typename T> struct common_type<T, sycl::ext::oneapi::bfloat16> {
1756+
using type = sycl::ext::oneapi::bfloat16;
1757+
};
1758+
} // namespace std
1759+
```
1760+
17321761
```cpp
1762+
namespace syclcompat{
1763+
1764+
// Trait for extended floating point definition
1765+
template <typename T>
1766+
struct is_floating_point : std::is_floating_point<T>{};
1767+
1768+
template <> struct is_floating_point<sycl::half> : std::true_type {};
1769+
1770+
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
1771+
template <> struct is_floating_point<sycl::ext::oneapi::bfloat16> : std::true_type {};
1772+
#endif
1773+
template <typename T>
1774+
1775+
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;
1776+
17331777
inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
17341778
unsigned int shift);
17351779
@@ -1756,11 +1800,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
17561800
// cbrt function wrapper.
17571801
template <typename ValueT>
17581802
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1759-
std::is_same_v<sycl::half, ValueT>,
1803+
std::is_same_v<ValueT, sycl::half>,
17601804
ValueT>
1761-
cbrt(ValueT val) {
1762-
return sycl::cbrt(static_cast<ValueT>(val));
1763-
}
1805+
cbrt(ValueT val);
17641806
17651807
// For floating-point types, `float` or `double` arguments are acceptable.
17661808
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
@@ -1798,6 +1840,10 @@ template <typename ValueT, typename ValueU>
17981840
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
17991841
fmax_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
18001842
1843+
template <typename ValueT, typename ValueU>
1844+
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1845+
fmax_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1846+
18011847
// Performs 2 elements comparison and returns the smaller one. If either of
18021848
// inputs is NaN, then return NaN.
18031849
template <typename ValueT, typename ValueU>
@@ -1807,6 +1853,10 @@ template <typename ValueT, typename ValueU>
18071853
inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
18081854
fmin_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
18091855
1856+
template <typename ValueT, typename ValueU>
1857+
inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1858+
fmin_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1859+
18101860
inline float pow(const float a, const int b) { return sycl::pown(a, b); }
18111861
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
18121862
@@ -1867,14 +1917,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
18671917
const BinaryOperation binary_op);
18681918
18691919
template <typename ValueT, class BinaryOperation>
1870-
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
1871-
const sycl::vec<ValueT, 2> b,
1872-
const BinaryOperation binary_op);
1920+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
1921+
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);
18731922
18741923
template <typename ValueT, class BinaryOperation>
1875-
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
1876-
const sycl::vec<ValueT, 2> b,
1877-
const BinaryOperation binary_op);
1924+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
1925+
unordered_compare_mask(const ValueT a, const ValueT b,
1926+
const BinaryOperation binary_op);
18781927
18791928
template <typename S, typename T> inline T vectorized_max(T a, T b);
18801929
@@ -1928,6 +1977,7 @@ inline dot_product_acc_t<T1, T2> dp2a_hi(T1 a, T2 b,
19281977
template <typename T1, typename T2>
19291978
inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
19301979
dot_product_acc_t<T1, T2> c);
1980+
} // namespace syclcompat
19311981
```
19321982
19331983
`vectorized_binary` computes the `BinaryOperation` for two operands,

0 commit comments

Comments
 (0)