-
Notifications
You must be signed in to change notification settings - Fork 795
[SYCL][COMPAT] Add bfloat16 support to several maths ops #15572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
bfloat16 can't be constexpr constructed
This includes sycl::half and sycl::ext::oneapi::bfloat16
- Adding bfloat to type lists - Making test fixtures work with sycl::vec<type,N> and sycl::marray<type,N>
No new functionality here
Mixed type cases aren't caught otherwise
- Added support for bfloat16 - Added tests for `sycl::vec<T,2>` & `sycl::marray<T,2>`
& a bit of tidying
& more tests
marray version never tested!
Function casts to sycl::complex<float> (no native bfloat16 support)
i.e. sycl::marray<bfloat16, N>, sycl::vec<bfloat16, N>
Also reimplement `fmax_nan` & `fmin_nan` in terms of improved max/min
- compare_mask & unordered_compare_mask support sycl::marray - sycl::marray tests for all `compare` APIs
This reverts commit 55e95ad.
| return sycl::ext::oneapi::experimental::isnan(a); | ||
| if constexpr (detail::support_bfloat16_math && | ||
| std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) { | ||
| return sycl::ext::oneapi::experimental::isnan(a); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If not supported, how about throwing an exception
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a static_assert since we'll know at compile time 👍
sycl/include/syclcompat/math.hpp
Outdated
| min(ValueT a, ValueU b) { | ||
| return sycl::fmin(static_cast<std::common_type_t<ValueT, ValueU>>(a), | ||
| static_cast<std::common_type_t<ValueT, ValueU>>(b)); | ||
| if constexpr (detail::support_bfloat16_math && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about throwing an exception if bfloat16 not supported
sycl/include/syclcompat/math.hpp
Outdated
| max(ValueT a, ValueU b) { | ||
| return sycl::fmax(static_cast<std::common_type_t<ValueT, ValueU>>(a), | ||
| static_cast<std::common_type_t<ValueT, ValueU>>(b)); | ||
| if constexpr (detail::support_bfloat16_math && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above.
sycl/include/syclcompat/math.hpp
Outdated
| }; | ||
|
|
||
| template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> { | ||
| using type = sycl::ext::oneapi::bfloat16; // std::common_type_t<float, T>; // |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the comments. std::commone_type_t<float, sycl::ext::oneapi::bfloat16> will be bfloat16, and won't be float, why do the comments say bfloat16 will be promoted to float?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Woops! Thanks for catching that. Stale comment from earlier impl.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for tackling this @joeatodd , It´ s quite a good PR
sycl/include/syclcompat/math.hpp
Outdated
| inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT> | ||
| cbrt(ValueT val) { | ||
| return sycl::cbrt(static_cast<ValueT>(val)); | ||
| if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to not check for detail::support_bfloat16_math here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, since we're casting to float & doing cbrt<float>, it's not needed. But I'm not confident casting is the right approach actually... The bfloat16 extension doesn't have cbrt.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have reverted the cbrt change in this PR. The math op isn't currently implemented by the bfloat16 math extension, and it wasn't actually part of the request from DCPT.
and remove stale comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sycl/include/syclcompat/math.hpp
Outdated
| inline std::enable_if_t<syclcompat::is_floating_point_v<ValueT>, ValueT> | ||
| cbrt(ValueT val) { | ||
| return sycl::cbrt(static_cast<ValueT>(val)); | ||
| if constexpr (std::is_same_v<sycl::ext::oneapi::bfloat16, ValueT>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, since we're casting to float & doing cbrt<float>, it's not needed. But I'm not confident casting is the right approach actually... The bfloat16 extension doesn't have cbrt.
sycl/include/syclcompat/math.hpp
Outdated
| }; | ||
|
|
||
| template <typename T> struct common_type<sycl::ext::oneapi::bfloat16, T> { | ||
| using type = sycl::ext::oneapi::bfloat16; // std::common_type_t<float, T>; // |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Woops! Thanks for catching that. Stale comment from earlier impl.
| return sycl::ext::oneapi::experimental::isnan(a); | ||
| if constexpr (detail::support_bfloat16_math && | ||
| std::is_same_v<ValueT, sycl::ext::oneapi::bfloat16>) { | ||
| return sycl::ext::oneapi::experimental::isnan(a); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a static_assert since we'll know at compile time 👍
|
@intel/llvm-gatekeepers this is ready to merge 🙏 |
Adds support for
sycl::ext::oneapi::bfloat16to:reluclampfmax_nanfmin_nanminmaxcompare_maskunordered_compare_mask