@@ -1729,7 +1729,51 @@ second operand, respectively. These three APIs return a single 32-bit value with
17291729the accumulated result, which is unsigned if both operands are `uint32_t` and
17301730signed otherwise.
17311731
1732+ Various maths functions are defined to operate on any floating point type.
1733+ `syclcompat::is_floating_point_v` extends the standard library's
1734+ `std::is_floating_point_v` to include `sycl::half` and, where available,
1735+ `sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides
1736+ a specialization of `std::common_type_t` for `sycl::ext::oneapi::bfloat16`,
1737+ though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in
1738+ future.
1739+
1740+ ```cpp
1741+ namespace std {
1742+ template <> struct common_type<sycl::ext::oneapi::bfloat16> {
1743+ using type = sycl::ext::oneapi::bfloat16;
1744+ };
1745+
1746+ template <>
1747+ struct common_type<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16> {
1748+ using type = sycl::ext::oneapi::bfloat16;
1749+ };
1750+
1751+ template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
1752+ using type = sycl::ext::oneapi::bfloat16;
1753+ };
1754+
1755+ template <typename T> struct common_type<T, sycl::ext::oneapi::bfloat16> {
1756+ using type = sycl::ext::oneapi::bfloat16;
1757+ };
1758+ } // namespace std
1759+ ```
1760+
17321761```cpp
1762+ namespace syclcompat{
1763+
1764+ // Trait for extended floating point definition
1765+ template <typename T>
1766+ struct is_floating_point : std::is_floating_point<T>{};
1767+
1768+ template <> struct is_floating_point<sycl::half> : std::true_type {};
1769+
1770+ #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
1771+ template <> struct is_floating_point<sycl::ext::oneapi::bfloat16> : std::true_type {};
1772+ #endif
1773+
1774+ template <typename T>
1775+ inline constexpr bool is_floating_point_v = is_floating_point<T>::value;
1776+
17331777inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
17341778 unsigned int shift);
17351779
@@ -1756,11 +1800,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
17561800// cbrt function wrapper.
17571801template <typename ValueT>
17581802inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1759- std::is_same_v<sycl::half, ValueT>,
1803+ std::is_same_v<ValueT, sycl::half>,
17601804 ValueT>
1761- cbrt(ValueT val) {
1762- return sycl::cbrt(static_cast<ValueT>(val));
1763- }
1805+ cbrt(ValueT val);
17641806
17651807// For floating-point types, `float` or `double` arguments are acceptable.
17661808// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
@@ -1798,6 +1840,10 @@ template <typename ValueT, typename ValueU>
17981840inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
17991841fmax_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
18001842
1843+ template <typename ValueT, typename ValueU>
1844+ inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1845+ fmax_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1846+
18011847// Performs 2 elements comparison and returns the smaller one. If either of
18021848// inputs is NaN, then return NaN.
18031849template <typename ValueT, typename ValueU>
@@ -1807,6 +1853,10 @@ template <typename ValueT, typename ValueU>
18071853inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
18081854fmin_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
18091855
1856+ template <typename ValueT, typename ValueU>
1857+ inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1858+ fmin_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1859+
18101860inline float pow(const float a, const int b) { return sycl::pown(a, b); }
18111861inline double pow(const double a, const int b) { return sycl::pown(a, b); }
18121862
@@ -1867,14 +1917,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
18671917 const BinaryOperation binary_op);
18681918
18691919template <typename ValueT, class BinaryOperation>
1870- inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
1871- const sycl::vec<ValueT, 2> b,
1872- const BinaryOperation binary_op);
1920+ inline std::enable_if_t<ValueT::size() == 2, unsigned>
1921+ compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);
18731922
18741923template <typename ValueT, class BinaryOperation>
1875- inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
1876- const sycl::vec<ValueT, 2> b,
1877- const BinaryOperation binary_op);
1924+ inline std::enable_if_t<ValueT::size() == 2, unsigned>
1925+ unordered_compare_mask(const ValueT a, const ValueT b,
1926+ const BinaryOperation binary_op);
18781927
18791928template <typename S, typename T> inline T vectorized_max(T a, T b);
18801929
@@ -1928,6 +1977,7 @@ inline dot_product_acc_t<T1, T2> dp2a_hi(T1 a, T2 b,
19281977template <typename T1, typename T2>
19291978inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
19301979 dot_product_acc_t<T1, T2> c);
1980+ } // namespace syclcompat
19311981```
19321982
19331983`vectorized_binary` computes the `BinaryOperation` for two operands,
0 commit comments