Skip to content

Commit 0e074f5

Browse files
laramiel and copybara-github
authored and committed
Fix comparisons for mxfloat types.
mxfloat.h used the generic comparison from float8, which assumed that the sign bit was the msb, causing most ordering comparisons involving negative values to fail. Add tests for mxfloat.

PiperOrigin-RevId: 837140127
Change-Id: Ie786d43f6d329f2b908d8e3e607fe05516309edd
1 parent de94f4f commit 0e074f5

File tree

3 files changed

+457
-22
lines changed

3 files changed

+457
-22
lines changed

tensorstore/util/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,3 +838,16 @@ tensorstore_cc_test(
838838
"@net_sourceforge_half//:half",
839839
],
840840
)
841+
842+
tensorstore_cc_test(
843+
name = "mxfloat_test",
844+
srcs = ["mxfloat_test.cc"],
845+
deps = [
846+
":bfloat16",
847+
":mxfloat",
848+
"@abseil-cpp//absl/base",
849+
"@abseil-cpp//absl/strings",
850+
"@googletest//:gtest_main",
851+
"@net_sourceforge_half//:half",
852+
],
853+
)

tensorstore/util/float8.h

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <stdint.h>
2020

2121
#include <algorithm>
22+
#include <climits>
2223
#include <cmath>
2324
#include <limits>
2425
#include <ostream>
@@ -55,6 +56,15 @@ limitations under the License.
5556
#include <bit>
5657
#endif
5758

59+
// Strongly suggest to the compiler that these functions are inlined.
60+
#ifndef TENSORSTORE_FLOATINLINE
61+
#if defined(_MSC_VER)
62+
#define TENSORSTORE_FLOATINLINE __forceinline
63+
#else
64+
#define TENSORSTORE_FLOATINLINE inline
65+
#endif
66+
#endif // TENSORSTORE_FLOATINLINE
67+
5868
namespace tensorstore {
5969
namespace float8_internal {
6070

@@ -74,6 +84,8 @@ class Float8Base {
7484
constexpr Float8Base(uint8_t rep, ConstructFromRepTag) : rep_{rep} {}
7585

7686
public:
87+
static constexpr int kBits = 8;
88+
7789
constexpr Float8Base() : rep_(0) {}
7890

7991
template <typename T>
@@ -123,10 +135,10 @@ class Float8Base {
123135

124136
// Conversions allowing saturation and truncation.
125137
template <bool kSaturate = false, bool kTruncate = false, typename From>
126-
static Derived ConvertFrom(From from);
138+
static inline Derived ConvertFrom(From from);
127139

128140
template <typename To, bool kSaturate = false, bool kTruncate = false>
129-
static To ConvertTo(const Derived& from);
141+
static inline To ConvertTo(const Derived& from);
130142

131143
// Operators via float32.
132144
Derived operator+(const Derived& other) const {
@@ -145,53 +157,53 @@ class Float8Base {
145157
return Derived{float{derived()} / float{other}};
146158
}
147159

148-
constexpr bool operator==(const Derived& other) const {
160+
constexpr inline bool operator==(const Derived& other) const {
149161
return Compare(derived(), other) == Ordering::kEquivalent;
150162
}
151163

152-
constexpr bool operator!=(const Derived& other) const {
164+
constexpr inline bool operator!=(const Derived& other) const {
153165
return Compare(derived(), other) != Ordering::kEquivalent;
154166
}
155167

156-
bool operator<(const Derived& other) const {
168+
TENSORSTORE_FLOATINLINE bool operator<(const Derived& other) const {
157169
return Compare(derived(), other) == Ordering::kLess;
158170
}
159171

160-
bool operator<=(const Derived& other) const {
172+
TENSORSTORE_FLOATINLINE bool operator<=(const Derived& other) const {
161173
return Compare(derived(), other) <= Ordering::kEquivalent;
162174
}
163175

164-
bool operator>(const Derived& other) const {
176+
TENSORSTORE_FLOATINLINE bool operator>(const Derived& other) const {
165177
return Compare(derived(), other) == Ordering::kGreater;
166178
}
167179

168-
bool operator>=(const Derived& other) const {
180+
TENSORSTORE_FLOATINLINE bool operator>=(const Derived& other) const {
169181
Ordering ordering = Compare(derived(), other);
170182
return ordering == Ordering::kGreater || ordering == Ordering::kEquivalent;
171183
}
172184

173185
// Compound assignment.
174-
Derived& operator+=(const Derived& other) {
186+
TENSORSTORE_FLOATINLINE Derived& operator+=(const Derived& other) {
175187
derived() = derived() + other;
176188
return derived();
177189
}
178190

179191
// for downsample_nditerable
180-
friend float operator+=(const float& a, Derived b) {
192+
friend TENSORSTORE_FLOATINLINE float operator+=(const float& a, Derived b) {
181193
return a + static_cast<float>(b);
182194
}
183195

184-
Derived& operator-=(const Derived& other) {
196+
TENSORSTORE_FLOATINLINE Derived& operator-=(const Derived& other) {
185197
derived() = derived() - other;
186198
return derived();
187199
}
188200

189-
Derived& operator*=(const Derived& other) {
201+
TENSORSTORE_FLOATINLINE Derived& operator*=(const Derived& other) {
190202
derived() = derived() * other;
191203
return derived();
192204
}
193205

194-
Derived& operator/=(const Derived& other) {
206+
TENSORSTORE_FLOATINLINE Derived& operator/=(const Derived& other) {
195207
derived() = derived() / other;
196208
return derived();
197209
}
@@ -223,14 +235,15 @@ class Float8Base {
223235
}
224236

225237
private:
226-
static std::pair<uint8_t, uint8_t> SignAndMagnitude(Derived x) {
238+
static TENSORSTORE_FLOATINLINE std::pair<uint8_t, uint8_t> SignAndMagnitude(
239+
Derived x) {
227240
const uint8_t x_abs_bits = absl::bit_cast<uint8_t>(abs(x));
228241
const uint8_t x_bits = absl::bit_cast<uint8_t>(x);
229-
const uint8_t x_sign = x_bits ^ x_abs_bits;
242+
const uint8_t x_sign = (x_bits ^ x_abs_bits) << (CHAR_BIT - Derived::kBits);
230243
return {x_sign, x_abs_bits};
231244
}
232-
static int8_t SignAndMagnitudeToTwosComplement(uint8_t sign,
233-
uint8_t magnitude) {
245+
static TENSORSTORE_FLOATINLINE int8_t
246+
SignAndMagnitudeToTwosComplement(uint8_t sign, uint8_t magnitude) {
234247
return magnitude ^ (static_cast<int8_t>(sign) < 0 ? -1 : 0);
235248
}
236249

@@ -492,9 +505,12 @@ struct numeric_limits_float8_base {
492505
static inline constexpr const bool is_integer = false;
493506
static inline constexpr const bool is_exact = false;
494507
static inline constexpr const bool has_quiet_NaN = true;
508+
// has_denorm and has_denorm_loss are deprecated in C++23.
509+
#if !defined(__cplusplus) || __cplusplus < 202302L
495510
static inline constexpr const std::float_denorm_style has_denorm =
496511
std::denorm_present;
497512
static inline constexpr const bool has_denorm_loss = false;
513+
#endif
498514
static inline constexpr const std::float_round_style round_style =
499515
std::round_to_nearest;
500516
static inline constexpr const bool is_bounded = true;
@@ -731,7 +747,8 @@ struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
731747
}
732748
static constexpr Float8e4m3fnuz infinity() {
733749
return Float8e4m3fnuz::FromRep(0x80);
734-
} // NaN.
750+
}
751+
// NaN.
735752
static constexpr Float8e4m3fnuz quiet_NaN() {
736753
return Float8e4m3fnuz::FromRep(0x80);
737754
}
@@ -946,10 +963,13 @@ constexpr inline bool isnan(const Float8e5m2fnuz& a) { return a.rep() == 0x80; }
946963

947964
template <typename Float8>
948965
constexpr inline bool(isinf)(const Float8Base<Float8>& a) {
949-
return std::numeric_limits<Float8>::has_infinity
950-
? abs(a.derived()).rep() ==
951-
std::numeric_limits<Float8>::infinity().rep()
952-
: false; // No inf representation.
966+
if constexpr (std::numeric_limits<Float8>::has_infinity) {
967+
return abs(a.derived()).rep() ==
968+
std::numeric_limits<Float8>::infinity().rep();
969+
} else {
970+
// No inf representation.
971+
return false;
972+
}
953973
}
954974

955975
template <typename Float8>

0 commit comments

Comments (0)