@@ -1048,6 +1048,10 @@ static inline unsigned int get_device_id(const sycl::device &dev);
10481048// Util function to get the number of available devices
10491049static inline unsigned int device_count();
10501050
1051+ // Util function to check whether a device supports some kinds of sycl::aspect.
1052+ static inline void
1053+ has_capability_or_fail(const sycl::device &dev,
1054+ const std::initializer_list<sycl::aspect> &props);
10511055} // syclcompat
10521056```
10531057
@@ -1725,7 +1729,51 @@ second operand, respectively. These three APIs return a single 32-bit value with
17251729the accumulated result, which is unsigned if both operands are `uint32_t` and
17261730signed otherwise.
17271731
1732+ Various maths functions are defined to operate on any floating point type.
1733+ `syclcompat::is_floating_point_v` extends the standard library's
1734+ `std::is_floating_point_v` to include `sycl::half` and, where available,
1735+ `sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides
1736+ a specialization of `std::common_type_t` for `sycl::ext::oneapi::bfloat16`,
1737+ though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in
1738+ future.
1739+
1740+ ```cpp
1741+ namespace std {
1742+ template <> struct common_type<sycl::ext::oneapi::bfloat16> {
1743+ using type = sycl::ext::oneapi::bfloat16;
1744+ };
1745+
1746+ template <>
1747+ struct common_type<sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16> {
1748+ using type = sycl::ext::oneapi::bfloat16;
1749+ };
1750+
1751+ template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
1752+ using type = sycl::ext::oneapi::bfloat16;
1753+ };
1754+
1755+ template <typename T> struct common_type<T, sycl::ext::oneapi::bfloat16> {
1756+ using type = sycl::ext::oneapi::bfloat16;
1757+ };
1758+ } // namespace std
1759+ ```
1760+
17281761```cpp
1762+ namespace syclcompat{
1763+
1764+ // Trait for extended floating point definition
1765+ template <typename T>
1766+ struct is_floating_point : std::is_floating_point<T>{};
1767+
1768+ template <> struct is_floating_point<sycl::half> : std::true_type {};
1769+
1770+ #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
1771+ template <> struct is_floating_point<sycl::ext::oneapi::bfloat16> : std::true_type {};
1772+ #endif
1773+ template <typename T>
1774+
1775+ inline constexpr bool is_floating_point_v = is_floating_point<T>::value;
1776+
17291777inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
17301778 unsigned int shift);
17311779
@@ -1752,11 +1800,9 @@ inline std::enable_if_t<ValueT::size() == 2, ValueT> isnan(const ValueT a);
17521800// cbrt function wrapper.
17531801template <typename ValueT>
17541802inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
1755- std::is_same_v<sycl::half, ValueT >,
1803+ std::is_same_v<ValueT, sycl::half>,
17561804 ValueT>
1757- cbrt(ValueT val) {
1758- return sycl::cbrt(static_cast<ValueT >(val));
1759- }
1805+ cbrt(ValueT val);
17601806
17611807// For floating-point types, `float` or `double` arguments are acceptable.
17621808// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
@@ -1794,6 +1840,10 @@ template <typename ValueT, typename ValueU>
17941840inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
17951841fmax_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
17961842
1843+ template <typename ValueT, typename ValueU>
1844+ inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1845+ fmax_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1846+
17971847// Performs 2 elements comparison and returns the smaller one. If either of
17981848// inputs is NaN, then return NaN.
17991849template <typename ValueT, typename ValueU>
@@ -1803,6 +1853,10 @@ template <typename ValueT, typename ValueU>
18031853inline sycl::vec<std::common_type_t<ValueT, ValueU>, 2>
18041854fmin_nan(const sycl::vec<ValueT, 2> a, const sycl::vec<ValueU, 2> b);
18051855
1856+ template <typename ValueT, typename ValueU>
1857+ inline sycl::marray<std::common_type_t<ValueT, ValueU>, 2>
1858+ fmin_nan(const sycl::marray<ValueT, 2> a, const sycl::marray<ValueU, 2> b);
1859+
18061860inline float pow(const float a, const int b) { return sycl::pown(a, b); }
18071861inline double pow(const double a, const int b) { return sycl::pown(a, b); }
18081862
@@ -1863,14 +1917,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
18631917 const BinaryOperation binary_op);
18641918
18651919template <typename ValueT, class BinaryOperation>
1866- inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
1867- const sycl::vec<ValueT, 2> b,
1868- const BinaryOperation binary_op);
1920+ inline std::enable_if_t<ValueT::size() == 2, unsigned>
1921+ compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);
18691922
18701923template <typename ValueT, class BinaryOperation>
1871- inline unsigned unordered_compare_mask(const sycl::vec <ValueT, 2> a,
1872- const sycl::vec< ValueT, 2> b,
1873- const BinaryOperation binary_op);
1924+ inline std::enable_if_t <ValueT::size() == 2, unsigned>
1925+ unordered_compare_mask(const ValueT a, const ValueT b,
1926+ const BinaryOperation binary_op);
18741927
18751928template <typename S, typename T> inline T vectorized_max(T a, T b);
18761929
@@ -1924,6 +1977,7 @@ inline dot_product_acc_t<T1, T2> dp2a_hi(T1 a, T2 b,
19241977template <typename T1, typename T2>
19251978inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,
19261979 dot_product_acc_t<T1, T2> c);
1980+ } // namespace syclcompat
19271981```
19281982
19291983`vectorized_binary` computes the `BinaryOperation` for two operands,
0 commit comments