Skip to content

Commit a3ca737

Browse files
committed
[SYCL] Add std::hash and std::numeric_limits specialization for bfloat16
This commit adds the missing std::hash and std::numeric_limits specializations for the sycl::ext::oneapi::bfloat16 class. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 421bfc1 commit a3ca737

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,21 @@ class bfloat16 {
141141
private:
142142
Bfloat16StorageT value;
143143

144+
// Private tag used to avoid constructor ambiguity.
145+
struct private_tag {
146+
explicit private_tag() = default;
147+
};
148+
149+
constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {}
150+
144151
// Explicit conversion functions
145152
static float to_float(const Bfloat16StorageT &a);
146153
static Bfloat16StorageT from_float(const float &a);
147154

155+
// Friend traits.
156+
friend std::numeric_limits<bfloat16>;
157+
friend std::hash<bfloat16>;
158+
148159
// Friend classes for vector operations
149160
friend class sycl::vec<bfloat16, 1>;
150161
friend class sycl::vec<bfloat16, 2>;
@@ -615,3 +626,80 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) {
615626
} // namespace ext::oneapi
616627
} // namespace _V1
617628
} // namespace sycl
629+
630+
// Specialization of some functions in namespace `std`.
631+
namespace std {
632+
633+
// Specialization of `std::hash<sycl::ext::oneapi::bfloat16>`.
634+
template <> struct hash<sycl::ext::oneapi::bfloat16> {
635+
size_t operator()(sycl::ext::oneapi::bfloat16 const &Key) const noexcept {
636+
return hash<uint16_t>{}(Key.value);
637+
}
638+
};
639+
640+
// Specialization of `std::numeric_limits<sycl::ext::oneapi::bfloat16>`.
641+
template <> struct numeric_limits<sycl::ext::oneapi::bfloat16> {
642+
// All following values are calculated based on description of each
643+
// function/value on https://en.cppreference.com/w/cpp/types/numeric_limits.
644+
static constexpr bool is_specialized = true;
645+
static constexpr bool is_signed = true;
646+
static constexpr bool is_integer = false;
647+
static constexpr bool is_exact = false;
648+
static constexpr bool has_infinity = true;
649+
static constexpr bool has_quiet_NaN = true;
650+
static constexpr bool has_signaling_NaN = true;
651+
static constexpr float_denorm_style has_denorm = denorm_present;
652+
static constexpr bool has_denorm_loss = false;
653+
static constexpr bool tinyness_before = false;
654+
static constexpr bool traps = false;
655+
static constexpr int max_exponent10 = 35;
656+
static constexpr int max_exponent = 127;
657+
static constexpr int min_exponent10 = -37;
658+
static constexpr int min_exponent = -126;
659+
static constexpr int radix = 2;
660+
static constexpr int max_digits10 = 4;
661+
static constexpr int digits = 8;
662+
static constexpr bool is_bounded = true;
663+
static constexpr int digits10 = 2;
664+
static constexpr bool is_modulo = false;
665+
static constexpr bool is_iec559 = true;
666+
static constexpr float_round_style round_style = round_to_nearest;
667+
668+
static constexpr const sycl::ext::oneapi::bfloat16(min)() noexcept {
669+
return {uint16_t(0x80), sycl::ext::oneapi::bfloat16::private_tag{}};
670+
}
671+
672+
static constexpr const sycl::ext::oneapi::bfloat16(max)() noexcept {
673+
return {uint16_t(0x7f7f), sycl::ext::oneapi::bfloat16::private_tag{}};
674+
}
675+
676+
static constexpr const sycl::ext::oneapi::bfloat16 lowest() noexcept {
677+
return {uint16_t(0xff7f), sycl::ext::oneapi::bfloat16::private_tag{}};
678+
}
679+
680+
static constexpr const sycl::ext::oneapi::bfloat16 epsilon() noexcept {
681+
return {uint16_t(0x3c00), sycl::ext::oneapi::bfloat16::private_tag{}};
682+
}
683+
684+
static constexpr const sycl::ext::oneapi::bfloat16 round_error() noexcept {
685+
return {uint16_t(0x3f00), sycl::ext::oneapi::bfloat16::private_tag{}};
686+
}
687+
688+
static constexpr const sycl::ext::oneapi::bfloat16 infinity() noexcept {
689+
return {uint16_t(0x7f80), sycl::ext::oneapi::bfloat16::private_tag{}};
690+
}
691+
692+
static constexpr const sycl::ext::oneapi::bfloat16 quiet_NaN() noexcept {
693+
return {uint16_t(0x7fc0), sycl::ext::oneapi::bfloat16::private_tag{}};
694+
}
695+
696+
static constexpr const sycl::ext::oneapi::bfloat16 signaling_NaN() noexcept {
697+
return {uint16_t(0xff81), sycl::ext::oneapi::bfloat16::private_tag{}};
698+
}
699+
700+
static constexpr const sycl::ext::oneapi::bfloat16 denorm_min() noexcept {
701+
return {uint16_t(0x1), sycl::ext::oneapi::bfloat16::private_tag{}};
702+
}
703+
};
704+
705+
} // namespace std
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/// Checks a numeric_limits specialization of bfloat16.
2+
3+
// RUN: %{build} -o %t.out
4+
// RUN: %{run} %t.out
5+
6+
#include <sycl/detail/core.hpp>
7+
#include <sycl/ext/oneapi/bfloat16.hpp>
8+
#include <sycl/ext/oneapi/experimental/bfloat16_math.hpp>
9+
10+
namespace sycl_ext = sycl::ext::oneapi;
11+
12+
using Limit = std::numeric_limits<sycl_ext::bfloat16>;
13+
14+
// Result of std::log10(2).
15+
constexpr float Log10_2 = 0.30103f;
16+
17+
// Helper constexpr ceil function.
18+
constexpr int ceil(float Val) {
19+
return Val + (float(int(Val)) == Val ? 0.f : 1.f);
20+
}
21+
22+
int Check(bool Condition, sycl_ext::bfloat16 Value, std::string CheckName) {
23+
if (!Condition)
24+
std::cout << "Failed " << CheckName << " for " << Value << std::endl;
25+
return !Condition;
26+
}
27+
28+
int CheckBfloat16(uint16_t Sign, uint16_t Exponent, uint16_t Significand) {
29+
const auto Value = sycl::bit_cast<sycl_ext::bfloat16>(
30+
uint16_t((Sign << 15) | (Exponent << 7) | Significand));
31+
32+
int Failed = 0;
33+
34+
Failed += Check(Limit::lowest() <= Value, Value, "lowest()");
35+
Failed += Check(Limit::max() >= Value, Value, "max()");
36+
37+
// min() is the lowest normal number, so if Value is negative, 0 or a
38+
// subnormal - the latter two being represented by a 0-exponent - min() must
39+
// be strictly greater.
40+
if (Sign || Exponent == 0x0)
41+
Failed += Check(Limit::min() > Value, Value, "min() (1)");
42+
else
43+
Failed += Check(Limit::min() <= Value, Value, "min() (2)");
44+
45+
// denorm_min() is the lowest subnormal number, so if Value is negative or 0
46+
// denorm_min() must be strictly greater.
47+
if (Sign || (Exponent == 0x0 && Significand == 0x0))
48+
Failed += Check(Limit::denorm_min() > Value, Value, "denorm_min() (1)");
49+
else
50+
Failed += Check(Limit::denorm_min() <= Value, Value, "denorm_min() (2)");
51+
52+
return Failed;
53+
}
54+
55+
int main() {
56+
static_assert(Limit::is_specialized);
57+
static_assert(Limit::is_signed);
58+
static_assert(!Limit::is_integer);
59+
static_assert(!Limit::is_exact);
60+
static_assert(Limit::has_infinity);
61+
static_assert(Limit::has_quiet_NaN);
62+
static_assert(Limit::has_signaling_NaN);
63+
static_assert(Limit::has_denorm == std::float_denorm_style::denorm_present);
64+
static_assert(!Limit::has_denorm_loss);
65+
static_assert(!Limit::tinyness_before);
66+
static_assert(!Limit::traps);
67+
static_assert(Limit::max_exponent10 == 35);
68+
static_assert(Limit::max_exponent == 127);
69+
static_assert(Limit::min_exponent10 == -37);
70+
static_assert(Limit::min_exponent == -126);
71+
static_assert(Limit::radix == 2);
72+
static_assert(Limit::digits == 8);
73+
static_assert(Limit::max_digits10 ==
74+
ceil(float(Limit::digits) * Log10_2 + 1.0f));
75+
static_assert(Limit::is_bounded);
76+
static_assert(Limit::digits10 == int(Limit::digits * Log10_2));
77+
static_assert(!Limit::is_modulo);
78+
static_assert(Limit::is_iec559);
79+
static_assert(Limit::round_style == std::float_round_style::round_to_nearest);
80+
81+
int Failed = 0;
82+
83+
Failed += Check(sycl_ext::experimental::isnan(Limit::quiet_NaN()),
84+
Limit::quiet_NaN(), "quiet_NaN()");
85+
Failed += Check(sycl_ext::experimental::isnan(Limit::signaling_NaN()),
86+
Limit::signaling_NaN(), "signaling_NaN()");
87+
// isinf does not exist for bfloat16 currently.
88+
Failed += Check(Limit::infinity() ==
89+
sycl::bit_cast<sycl_ext::bfloat16>(uint16_t(0xff << 7)),
90+
Limit::infinity(), "infinity()");
91+
Failed += Check(Limit::round_error() == sycl_ext::bfloat16(0.5f),
92+
Limit::round_error(), "round_error()");
93+
Failed += Check(sycl_ext::bfloat16{1.0f} + Limit::epsilon() >
94+
sycl_ext::bfloat16{1.0f},
95+
Limit::epsilon(), "epsilon()");
96+
97+
for (uint16_t Sign : {0, 1})
98+
for (uint16_t Exponent = 0; Exponent < 0xff; ++Exponent)
99+
for (uint16_t Significand = 0; Significand < 0x7f; ++Significand)
100+
Failed += CheckBfloat16(Sign, Exponent, Significand);
101+
102+
return Failed;
103+
}

0 commit comments

Comments
 (0)