Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ template <typename T> auto convertToOpenCLType(T &&x) {
} else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) {
// On host, don't interpret BF16 as uint16.
#ifdef __SYCL_DEVICE_ONLY__
using OpenCLType = typename no_ref::Bfloat16StorageT;
return sycl::bit_cast<OpenCLType>(x);
return sycl::bit_cast<uint16_t>(x);
#else
return std::forward<T>(x);
#endif
Expand Down
3 changes: 1 addition & 2 deletions sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,7 @@ vec<convertT, NumElements> vec<DataT, NumElements>::convert() const {
#endif
bool, /*->*/ std::uint8_t, //
sycl::half, /*->*/ sycl::detail::half_impl::StorageT, //
sycl::ext::oneapi::bfloat16,
/*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, //
sycl::ext::oneapi::bfloat16, /*->*/ uint16_t, //
char, /*->*/ detail::ConvertToOpenCLType_t<char>, //
DataT, /*->*/ DataT //
>::type
Expand Down
22 changes: 13 additions & 9 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ namespace ext::oneapi {

class bfloat16 {
public:
using Bfloat16StorageT = uint16_t;
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
using Bfloat16StorageT
__SYCL_DEPRECATED("bfloat16::Bfloat16StorageT is non-standard and has "
"been deprecated.") = uint16_t;
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We run the risk of removing this as part of the upcoming ABI-break window. Maybe we should have a comment to prevent that from happening? Otherwise there's probably not much point in deprecating it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I say take the risk of breaking someone, because it is unlikely that anyone uses it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm doing it mostly for trunk users, but I do plan to remove it ASAP (in the next major release).


bfloat16() = default;
~bfloat16() = default;
Expand Down Expand Up @@ -58,7 +62,7 @@ class bfloat16 {
friend bfloat16 operator-(const bfloat16 &lhs) {
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
(__SYCL_CUDA_ARCH__ >= 800)
Bfloat16StorageT res;
uint16_t res;
asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
return bit_cast<bfloat16>(res);
#else
Expand Down Expand Up @@ -146,18 +150,18 @@ class bfloat16 {
#endif

private:
Bfloat16StorageT value;
uint16_t value;

// Private tag used to avoid constructor ambiguity: it is taken as an extra
// parameter by the private raw-storage constructor declared below.
struct private_tag {
// Explicit, so a tag cannot be created implicitly from an empty `{}`.
explicit private_tag() = default;
};

constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {}
constexpr bfloat16(uint16_t Value, private_tag) : value{Value} {}

// Explicit conversion functions
static float to_float(const Bfloat16StorageT &a);
static Bfloat16StorageT from_float(const float &a);
static float to_float(const uint16_t &a);
static uint16_t from_float(const float &a);

// Friend traits.
friend std::numeric_limits<bfloat16>;
Expand All @@ -178,7 +182,7 @@ class bfloat16 {
extern "C" __DPCPP_SYCL_EXTERNAL float
__devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;
#endif
inline float bfloat16::to_float(const bfloat16::Bfloat16StorageT &a) {
inline float bfloat16::to_float(const uint16_t &a) {
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
return __devicelib_ConvertBF16ToFINTEL(a);
#else
Expand Down Expand Up @@ -213,11 +217,11 @@ inline uint16_t from_float_to_uint16_t(const float &a) {
extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
__devicelib_ConvertFToBF16INTEL(const float &) noexcept;
#endif
inline bfloat16::Bfloat16StorageT bfloat16::from_float(const float &a) {
inline uint16_t bfloat16::from_float(const float &a) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
#if (__SYCL_CUDA_ARCH__ >= 800)
Bfloat16StorageT res;
uint16_t res;
asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a));
return res;
#else
Expand Down
6 changes: 3 additions & 3 deletions sycl/test-e2e/BFloat16/bfloat_hw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ using get_uint_type_of_size = typename std::conditional_t<
std::conditional_t<Size == 8, uint64_t, void>>>>;

using bfloat16 = sycl::ext::oneapi::bfloat16;
using Bfloat16StorageT = get_uint_type_of_size<sizeof(bfloat16)>;
static_assert(sizeof(bfloat16) == sizeof(uint16_t));

bool test(float Val, Bfloat16StorageT Bits) {
bool test(float Val, uint16_t Bits) {
std::cout << "Value: " << Val << " Bits: " << std::hex << "0x" << Bits
<< std::dec << "...\n";
bool Passed = true;
{
std::cout << " float -> bfloat16 conversion ...";
Bfloat16StorageT RawVal = sycl::bit_cast<Bfloat16StorageT>(bfloat16(Val));
auto RawVal = sycl::bit_cast<uint16_t>(bfloat16(Val));
bool Res = (RawVal == Bits);
Passed &= Res;

Expand Down