-
Notifications
You must be signed in to change notification settings - Fork 790
[SYCL][Docs] Add std::hash and std::numeric_limits specialization for bfloat16 #19838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
a3ca737
8a7ad74
ca1c081
ed5a30e
8f94fb0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -141,10 +141,21 @@ class bfloat16 { | |
private: | ||
Bfloat16StorageT value; | ||
|
||
// Private tag used to avoid constructor ambiguity. | ||
struct private_tag { | ||
explicit private_tag() = default; | ||
}; | ||
|
||
constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {} | ||
|
||
// Explicit conversion functions | ||
static float to_float(const Bfloat16StorageT &a); | ||
static Bfloat16StorageT from_float(const float &a); | ||
|
||
// Friend traits. | ||
friend std::numeric_limits<bfloat16>; | ||
friend std::hash<bfloat16>; | ||
|
||
// Friend classes for vector operations | ||
friend class sycl::vec<bfloat16, 1>; | ||
friend class sycl::vec<bfloat16, 2>; | ||
|
@@ -615,3 +626,80 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) { | |
} // namespace ext::oneapi | ||
} // namespace _V1 | ||
} // namespace sycl | ||
|
||
// Specialization of some functions in namespace `std`. | ||
namespace std { | ||
Comment on lines
+629
to
+630
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Subjective, but I prefer
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't have a strong preference, but the current style is what we use for |
||
|
||
// Specialization of `std::hash<sycl::ext::oneapi::bfloat16>`. | ||
template <> struct hash<sycl::ext::oneapi::bfloat16> { | ||
size_t operator()(sycl::ext::oneapi::bfloat16 const &Key) const noexcept { | ||
return hash<uint16_t>{}(Key.value); | ||
} | ||
}; | ||
|
||
// Specialization of `std::numeric_limits<sycl::ext::oneapi::bfloat16>`. | ||
template <> struct numeric_limits<sycl::ext::oneapi::bfloat16> { | ||
// All following values are calculated based on description of each | ||
// function/value on https://en.cppreference.com/w/cpp/types/numeric_limits. | ||
static constexpr bool is_specialized = true; | ||
static constexpr bool is_signed = true; | ||
static constexpr bool is_integer = false; | ||
static constexpr bool is_exact = false; | ||
static constexpr bool has_infinity = true; | ||
static constexpr bool has_quiet_NaN = true; | ||
static constexpr bool has_signaling_NaN = true; | ||
static constexpr float_denorm_style has_denorm = denorm_present; | ||
static constexpr bool has_denorm_loss = false; | ||
static constexpr bool tinyness_before = false; | ||
static constexpr bool traps = false; | ||
static constexpr int max_exponent10 = 35; | ||
static constexpr int max_exponent = 127; | ||
static constexpr int min_exponent10 = -37; | ||
static constexpr int min_exponent = -126; | ||
static constexpr int radix = 2; | ||
static constexpr int max_digits10 = 4; | ||
static constexpr int digits = 8; | ||
static constexpr bool is_bounded = true; | ||
static constexpr int digits10 = 2; | ||
static constexpr bool is_modulo = false; | ||
static constexpr bool is_iec559 = true; | ||
static constexpr float_round_style round_style = round_to_nearest; | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16(min)() noexcept { | ||
return {uint16_t(0x80), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16(max)() noexcept { | ||
return {uint16_t(0x7f7f), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 lowest() noexcept { | ||
return {uint16_t(0xff7f), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 epsilon() noexcept { | ||
return {uint16_t(0x3c00), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 round_error() noexcept { | ||
return {uint16_t(0x3f00), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 infinity() noexcept { | ||
return {uint16_t(0x7f80), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 quiet_NaN() noexcept { | ||
return {uint16_t(0x7fc0), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 signaling_NaN() noexcept { | ||
return {uint16_t(0xff81), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
|
||
static constexpr const sycl::ext::oneapi::bfloat16 denorm_min() noexcept { | ||
return {uint16_t(0x1), sycl::ext::oneapi::bfloat16::private_tag{}}; | ||
} | ||
}; | ||
|
||
} // namespace std |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/// Checks a numeric_limits specialization of bfloat16. | ||
|
||
// RUN: %{build} -o %t.out | ||
// RUN: %{run} %t.out | ||
steffenlarsen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
#include <sycl/detail/core.hpp> | ||
#include <sycl/ext/oneapi/bfloat16.hpp> | ||
#include <sycl/ext/oneapi/experimental/bfloat16_math.hpp> | ||
|
||
namespace sycl_ext = sycl::ext::oneapi; | ||
|
||
using Limit = std::numeric_limits<sycl_ext::bfloat16>; | ||
|
||
// Result of std::log10(2). | ||
constexpr float Log10_2 = 0.30103f; | ||
|
||
// Helper constexpr ceil function. | ||
constexpr int constexpr_ceil(float Val) { | ||
return Val + (float(int(Val)) == Val ? 0.f : 1.f); | ||
} | ||
|
||
int Check(bool Condition, sycl_ext::bfloat16 Value, std::string CheckName) { | ||
if (!Condition) | ||
std::cout << "Failed " << CheckName << " for " << Value << std::endl; | ||
return !Condition; | ||
} | ||
|
||
int CheckBfloat16(uint16_t Sign, uint16_t Exponent, uint16_t Significand) { | ||
const auto Value = sycl::bit_cast<sycl_ext::bfloat16>( | ||
uint16_t((Sign << 15) | (Exponent << 7) | Significand)); | ||
|
||
int Failed = 0; | ||
|
||
Failed += Check(Limit::lowest() <= Value, Value, "lowest()"); | ||
Failed += Check(Limit::max() >= Value, Value, "max()"); | ||
|
||
// min() is the lowest normal number, so if Value is negative, 0 or a | ||
// subnormal - the latter two being represented by a 0-exponent - min() must | ||
// be strictly greater. | ||
if (Sign || Exponent == 0x0) | ||
Failed += Check(Limit::min() > Value, Value, "min() (1)"); | ||
else | ||
Failed += Check(Limit::min() <= Value, Value, "min() (2)"); | ||
|
||
// denorm_min() is the lowest subnormal number, so if Value is negative or 0 | ||
// denorm_min() must be strictly greater. | ||
if (Sign || (Exponent == 0x0 && Significand == 0x0)) | ||
Failed += Check(Limit::denorm_min() > Value, Value, "denorm_min() (1)"); | ||
else | ||
Failed += Check(Limit::denorm_min() <= Value, Value, "denorm_min() (2)"); | ||
|
||
return Failed; | ||
} | ||
|
||
int main() { | ||
static_assert(Limit::is_specialized); | ||
static_assert(Limit::is_signed); | ||
static_assert(!Limit::is_integer); | ||
static_assert(!Limit::is_exact); | ||
static_assert(Limit::has_infinity); | ||
static_assert(Limit::has_quiet_NaN); | ||
static_assert(Limit::has_signaling_NaN); | ||
static_assert(Limit::has_denorm == std::float_denorm_style::denorm_present); | ||
static_assert(!Limit::has_denorm_loss); | ||
static_assert(!Limit::tinyness_before); | ||
static_assert(!Limit::traps); | ||
static_assert(Limit::max_exponent10 == 35); | ||
static_assert(Limit::max_exponent == 127); | ||
static_assert(Limit::min_exponent10 == -37); | ||
static_assert(Limit::min_exponent == -126); | ||
static_assert(Limit::radix == 2); | ||
static_assert(Limit::digits == 8); | ||
static_assert(Limit::max_digits10 == | ||
constexpr_ceil(float(Limit::digits) * Log10_2 + 1.0f)); | ||
static_assert(Limit::is_bounded); | ||
static_assert(Limit::digits10 == int(Limit::digits * Log10_2)); | ||
static_assert(!Limit::is_modulo); | ||
static_assert(Limit::is_iec559); | ||
static_assert(Limit::round_style == std::float_round_style::round_to_nearest); | ||
|
||
int Failed = 0; | ||
|
||
Failed += Check(sycl_ext::experimental::isnan(Limit::quiet_NaN()), | ||
Limit::quiet_NaN(), "quiet_NaN()"); | ||
Failed += Check(sycl_ext::experimental::isnan(Limit::signaling_NaN()), | ||
Limit::signaling_NaN(), "signaling_NaN()"); | ||
// isinf does not exist for bfloat16 currently. | ||
Failed += Check(Limit::infinity() == | ||
sycl::bit_cast<sycl_ext::bfloat16>(uint16_t(0xff << 7)), | ||
Limit::infinity(), "infinity()"); | ||
Failed += Check(Limit::round_error() == sycl_ext::bfloat16(0.5f), | ||
Limit::round_error(), "round_error()"); | ||
Failed += Check(sycl_ext::bfloat16{1.0f} + Limit::epsilon() > | ||
sycl_ext::bfloat16{1.0f}, | ||
Limit::epsilon(), "epsilon()"); | ||
|
||
for (uint16_t Sign : {0, 1}) | ||
for (uint16_t Exponent = 0; Exponent < 0xff; ++Exponent) | ||
for (uint16_t Significand = 0; Significand < 0x7f; ++Significand) | ||
Failed += CheckBfloat16(Sign, Exponent, Significand); | ||
|
||
return Failed; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have any guarantees on the layout/size? If so, can we use
bit_cast
instead of friendship?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do, but
bit_cast
in our implementation is only sometimesconstexpr
, so the friendship fornumeric_limits
lets it be unconditionallyconstexpr
. Forstd::hash
it should be fine to drop the friendship though!