Skip to content

Conversation

@joeatodd
Copy link
Contributor

@joeatodd joeatodd commented Oct 2, 2024

Adds support for sycl::ext::oneapi::bfloat16 to:

  • relu
  • clamp
  • fmax_nan
  • fmin_nan
  • min
  • max
  • compare_mask
  • unordered_compare_mask

bfloat16 can't be constexpr constructed
This includes sycl::half and sycl::ext::oneapi::bfloat16
- Adding bfloat to type lists
- Making test fixtures work with sycl::vec<type,N> and
sycl::marray<type,N>
No new functionality here
Mixed type cases aren't caught otherwise
- Added support for bfloat16
- Added tests for `sycl::vec<T,2>` & `sycl::marray<T,2>`
marray version never tested!
Function casts to sycl::complex<float> (no native bfloat16 support)
i.e. sycl::marray<bfloat16, N>, sycl::vec<bfloat16, N>
Also reimplement `fmax_nan` & `fmin_nan` in terms of improved max/min
- compare_mask & unordered_compare_mask support sycl::marray
- sycl::marray tests for all `compare` APIs
@joeatodd joeatodd requested a review from a team as a code owner October 2, 2024 11:28
return sycl::ext::oneapi::experimental::isnan(a);
if constexpr (detail::support_bfloat16_math &&
std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::isnan(a);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not supported, how about throwing an exception

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a static_assert since we'll know at compile time 👍

min(ValueT a, ValueU b) {
return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a),
static_cast<std::common_type_t<ValueT, ValueU>>(b));
if constexpr (detail::support_bfloat16_math &&
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about throwing an exception if bfloat16 not supported

max(ValueT a, ValueU b) {
return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a),
static_cast<std::common_type_t<ValueT, ValueU>>(b));
if constexpr (detail::support_bfloat16_math &&
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

};

template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
using type = sycl::ext::oneapi::bfloat16; // std::common_type_t<float, T>; //
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the comments. std::commone_type_t<float, sycl::ext::oneapi::bfloat16> will be bfloat16, and won't be float, why do the comments say bfloat16 will be promoted to float?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops! Thanks for catching that. Stale comment from earlier impl.

Copy link
Contributor

@Alcpz Alcpz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for tackling this @joeatodd , It´ s quite a good PR

inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
cbrt(ValueT val) {
return sycl::cbrt(static_cast<ValueT>(val));
if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to not check for detail::support_bfloat16_math here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, since we're casting to float & doing cbrt<float>, it's not needed. But I'm not confident casting is the right approach actually... The bfloat16 extension doesn't have cbrt.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have reverted the cbrt change in this PR. The math op isn't currently implemented by the bfloat16 math extension, and it wasn't actually part of the request from DCPT.

Copy link
Contributor Author

@joeatodd joeatodd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review comments addressed. Thanks @Alcpz and @ziranzha for your input 🙏

inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT>
cbrt(ValueT val) {
return sycl::cbrt(static_cast<ValueT>(val));
if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, since we're casting to float & doing cbrt<float>, it's not needed. But I'm not confident casting is the right approach actually... The bfloat16 extension doesn't have cbrt.

};

template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> {
using type = sycl::ext::oneapi::bfloat16; // std::common_type_t<float, T>; //
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops! Thanks for catching that. Stale comment from earlier impl.

return sycl::ext::oneapi::experimental::isnan(a);
if constexpr (detail::support_bfloat16_math &&
std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::isnan(a);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a static_assert since we'll know at compile time 👍

@joeatodd
Copy link
Contributor Author

@intel/llvm-gatekeepers this is ready to merge 🙏

@martygrant martygrant merged commit 1791115 into intel:sycl Oct 17, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants