Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,21 @@ class bfloat16 {
private:
Bfloat16StorageT value;

// Private tag used to avoid constructor ambiguity.
struct private_tag {
explicit private_tag() = default;
};

constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {}

// Explicit conversion functions
static float to_float(const Bfloat16StorageT &a);
static Bfloat16StorageT from_float(const float &a);

// Friend traits.
friend std::numeric_limits<bfloat16>;
friend std::hash<bfloat16>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any guarantees on the layout/size? If so, can we use bit_cast instead of friendship?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do, but bit_cast in our implementation is only sometimes constexpr, so the friendship for numeric_limits lets it be unconditionally constexpr. For std::hash it should be fine to drop the friendship though!


// Friend classes for vector operations
friend class sycl::vec<bfloat16, 1>;
friend class sycl::vec<bfloat16, 2>;
Expand Down Expand Up @@ -615,3 +626,80 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) {
} // namespace ext::oneapi
} // namespace _V1
} // namespace sycl

// Specialization of some functions in namespace `std`.
namespace std {
Comment on lines +629 to +630
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subjective, but I prefer

// in global ns
template <> struct std::{type}<types...> { ... };

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong preference, but the current style is what we use for hash in other headers, so I would like to keep it as is. Then we can change it at a global level if we want.


// Specialization of `std::hash<sycl::ext::oneapi::bfloat16>`.
template <> struct hash<sycl::ext::oneapi::bfloat16> {
size_t operator()(sycl::ext::oneapi::bfloat16 const &Key) const noexcept {
return hash<uint16_t>{}(Key.value);
}
};

// Specialization of `std::numeric_limits<sycl::ext::oneapi::bfloat16>`.
template <> struct numeric_limits<sycl::ext::oneapi::bfloat16> {
// All following values are calculated based on description of each
// function/value on https://en.cppreference.com/w/cpp/types/numeric_limits.
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
static constexpr bool tinyness_before = false;
static constexpr bool traps = false;
static constexpr int max_exponent10 = 35;
static constexpr int max_exponent = 127;
static constexpr int min_exponent10 = -37;
static constexpr int min_exponent = -126;
static constexpr int radix = 2;
static constexpr int max_digits10 = 4;
static constexpr int digits = 8;
static constexpr bool is_bounded = true;
static constexpr int digits10 = 2;
static constexpr bool is_modulo = false;
static constexpr bool is_iec559 = true;
static constexpr float_round_style round_style = round_to_nearest;

static constexpr const sycl::ext::oneapi::bfloat16(min)() noexcept {
return {uint16_t(0x80), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16(max)() noexcept {
return {uint16_t(0x7f7f), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 lowest() noexcept {
return {uint16_t(0xff7f), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 epsilon() noexcept {
return {uint16_t(0x3c00), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 round_error() noexcept {
return {uint16_t(0x3f00), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 infinity() noexcept {
return {uint16_t(0x7f80), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 quiet_NaN() noexcept {
return {uint16_t(0x7fc0), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 signaling_NaN() noexcept {
return {uint16_t(0xff81), sycl::ext::oneapi::bfloat16::private_tag{}};
}

static constexpr const sycl::ext::oneapi::bfloat16 denorm_min() noexcept {
return {uint16_t(0x1), sycl::ext::oneapi::bfloat16::private_tag{}};
}
};

} // namespace std
103 changes: 103 additions & 0 deletions sycl/test-e2e/BFloat16/bfloat16_limits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/// Checks a numeric_limits specialization of bfloat16.

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <sycl/detail/core.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16_math.hpp>

namespace sycl_ext = sycl::ext::oneapi;

using Limit = std::numeric_limits<sycl_ext::bfloat16>;

// Result of std::log10(2).
constexpr float Log10_2 = 0.30103f;

// Helper constexpr ceil function.
constexpr int constexpr_ceil(float Val) {
return Val + (float(int(Val)) == Val ? 0.f : 1.f);
}

int Check(bool Condition, sycl_ext::bfloat16 Value, std::string CheckName) {
if (!Condition)
std::cout << "Failed " << CheckName << " for " << Value << std::endl;
return !Condition;
}

int CheckBfloat16(uint16_t Sign, uint16_t Exponent, uint16_t Significand) {
const auto Value = sycl::bit_cast<sycl_ext::bfloat16>(
uint16_t((Sign << 15) | (Exponent << 7) | Significand));

int Failed = 0;

Failed += Check(Limit::lowest() <= Value, Value, "lowest()");
Failed += Check(Limit::max() >= Value, Value, "max()");

// min() is the lowest normal number, so if Value is negative, 0 or a
// subnormal - the latter two being represented by a 0-exponent - min() must
// be strictly greater.
if (Sign || Exponent == 0x0)
Failed += Check(Limit::min() > Value, Value, "min() (1)");
else
Failed += Check(Limit::min() <= Value, Value, "min() (2)");

// denorm_min() is the lowest subnormal number, so if Value is negative or 0
// denorm_min() must be strictly greater.
if (Sign || (Exponent == 0x0 && Significand == 0x0))
Failed += Check(Limit::denorm_min() > Value, Value, "denorm_min() (1)");
else
Failed += Check(Limit::denorm_min() <= Value, Value, "denorm_min() (2)");

return Failed;
}

int main() {
static_assert(Limit::is_specialized);
static_assert(Limit::is_signed);
static_assert(!Limit::is_integer);
static_assert(!Limit::is_exact);
static_assert(Limit::has_infinity);
static_assert(Limit::has_quiet_NaN);
static_assert(Limit::has_signaling_NaN);
static_assert(Limit::has_denorm == std::float_denorm_style::denorm_present);
static_assert(!Limit::has_denorm_loss);
static_assert(!Limit::tinyness_before);
static_assert(!Limit::traps);
static_assert(Limit::max_exponent10 == 35);
static_assert(Limit::max_exponent == 127);
static_assert(Limit::min_exponent10 == -37);
static_assert(Limit::min_exponent == -126);
static_assert(Limit::radix == 2);
static_assert(Limit::digits == 8);
static_assert(Limit::max_digits10 ==
constexpr_ceil(float(Limit::digits) * Log10_2 + 1.0f));
static_assert(Limit::is_bounded);
static_assert(Limit::digits10 == int(Limit::digits * Log10_2));
static_assert(!Limit::is_modulo);
static_assert(Limit::is_iec559);
static_assert(Limit::round_style == std::float_round_style::round_to_nearest);

int Failed = 0;

Failed += Check(sycl_ext::experimental::isnan(Limit::quiet_NaN()),
Limit::quiet_NaN(), "quiet_NaN()");
Failed += Check(sycl_ext::experimental::isnan(Limit::signaling_NaN()),
Limit::signaling_NaN(), "signaling_NaN()");
// isinf does not exist for bfloat16 currently.
Failed += Check(Limit::infinity() ==
sycl::bit_cast<sycl_ext::bfloat16>(uint16_t(0xff << 7)),
Limit::infinity(), "infinity()");
Failed += Check(Limit::round_error() == sycl_ext::bfloat16(0.5f),
Limit::round_error(), "round_error()");
Failed += Check(sycl_ext::bfloat16{1.0f} + Limit::epsilon() >
sycl_ext::bfloat16{1.0f},
Limit::epsilon(), "epsilon()");

for (uint16_t Sign : {0, 1})
for (uint16_t Exponent = 0; Exponent < 0xff; ++Exponent)
for (uint16_t Significand = 0; Significand < 0x7f; ++Significand)
Failed += CheckBfloat16(Sign, Exponent, Significand);

return Failed;
}
Loading