Skip to content

Commit b342280

Browse files
rsuderman and claude committed
[Support] Fix Float16/BFloat16 compilation issues with implicit conversions
Use std::bit_cast for constexpr support, remove ambiguous int16_t constructor (use fromBits() instead), and use implicit conversions to/from float to avoid operator ambiguity. Remove member binary operators since arithmetic works through float conversion. Co-Authored-By: Claude Opus 4.5 <[email protected]> Signed-off-by: Rob Suderman <[email protected]>
1 parent 8d28a64 commit b342280

File tree

1 file changed

+55
-135
lines changed

1 file changed

+55
-135
lines changed

include/fusilli/support/float_types.h

Lines changed: 55 additions & 135 deletions
Original file line number | Diff line number | Diff line change
@@ -15,109 +15,71 @@
1515
#ifndef FUSILLI_SUPPORT_FLOAT_TYPES_H
1616
#define FUSILLI_SUPPORT_FLOAT_TYPES_H
1717

18+
#include <bit>
1819
#include <cmath>
1920
#include <cstdint>
20-
#include <cstring>
2121
#include <limits>
2222

2323
namespace fusilli {
2424

2525
// IEEE 754 half-precision floating point (Float16)
2626
// Format: 1 sign bit, 5 exponent bits, 10 mantissa bits
27+
//
28+
// This type provides implicit conversions to/from float, allowing seamless
29+
// interoperability with float arithmetic. All operations are performed in
30+
// float precision through these conversions.
2731
struct Float16 {
2832
int16_t data;
2933

30-
Float16() : data(0) {}
31-
explicit Float16(int16_t raw) : data(raw) {}
34+
constexpr Float16() : data(0) {}
3235

33-
// Construct from float
34-
explicit Float16(float f) { data = floatToFp16Bits(f); }
36+
// Construct from float (handles double via implicit conversion)
37+
constexpr Float16(float f) : data(floatToFp16Bits(f)) {}
3538

3639
// Convert to float
37-
float toFloat() const { return fp16BitsToFloat(data); }
38-
39-
// Implicit conversion to float for arithmetic
40-
explicit operator float() const { return toFloat(); }
41-
42-
// Arithmetic operators (perform math in float32)
43-
Float16 operator+(const Float16 &other) const {
44-
return Float16(toFloat() + other.toFloat());
45-
}
40+
constexpr float toFloat() const { return fp16BitsToFloat(data); }
4641

47-
Float16 operator-(const Float16 &other) const {
48-
return Float16(toFloat() - other.toFloat());
49-
}
42+
// Implicit conversion to float for seamless interoperability
43+
// Arithmetic and comparisons work through this conversion
44+
constexpr operator float() const { return toFloat(); }
5045

51-
Float16 operator*(const Float16 &other) const {
52-
return Float16(toFloat() * other.toFloat());
53-
}
54-
55-
Float16 operator/(const Float16 &other) const {
56-
return Float16(toFloat() / other.toFloat());
57-
}
58-
59-
Float16 operator-() const { return Float16(-toFloat()); }
46+
// Unary negation
47+
constexpr Float16 operator-() const { return Float16(-toFloat()); }
6048

6149
// Compound assignment operators
62-
Float16 &operator+=(const Float16 &other) {
63-
*this = *this + other;
50+
constexpr Float16 &operator+=(Float16 other) {
51+
*this = Float16(toFloat() + other.toFloat());
6452
return *this;
6553
}
6654

67-
Float16 &operator-=(const Float16 &other) {
68-
*this = *this - other;
55+
constexpr Float16 &operator-=(Float16 other) {
56+
*this = Float16(toFloat() - other.toFloat());
6957
return *this;
7058
}
7159

72-
Float16 &operator*=(const Float16 &other) {
73-
*this = *this * other;
60+
constexpr Float16 &operator*=(Float16 other) {
61+
*this = Float16(toFloat() * other.toFloat());
7462
return *this;
7563
}
7664

77-
Float16 &operator/=(const Float16 &other) {
78-
*this = *this / other;
65+
constexpr Float16 &operator/=(Float16 other) {
66+
*this = Float16(toFloat() / other.toFloat());
7967
return *this;
8068
}
8169

82-
// Comparison operators
83-
bool operator==(const Float16 &other) const {
84-
return toFloat() == other.toFloat();
85-
}
86-
87-
bool operator!=(const Float16 &other) const {
88-
return toFloat() != other.toFloat();
89-
}
90-
91-
bool operator<(const Float16 &other) const {
92-
return toFloat() < other.toFloat();
93-
}
94-
95-
bool operator<=(const Float16 &other) const {
96-
return toFloat() <= other.toFloat();
97-
}
98-
99-
bool operator>(const Float16 &other) const {
100-
return toFloat() > other.toFloat();
101-
}
102-
103-
bool operator>=(const Float16 &other) const {
104-
return toFloat() >= other.toFloat();
105-
}
106-
10770
// Create from raw bits
108-
static Float16 fromBits(int16_t bits) {
71+
static constexpr Float16 fromBits(int16_t bits) {
10972
Float16 result;
11073
result.data = bits;
11174
return result;
11275
}
11376

11477
// Get raw bits
115-
int16_t toBits() const { return data; }
78+
constexpr int16_t toBits() const { return data; }
11679

11780
private:
118-
static int16_t floatToFp16Bits(float f) {
119-
uint32_t bits;
120-
std::memcpy(&bits, &f, sizeof(bits));
81+
static constexpr int16_t floatToFp16Bits(float f) {
82+
uint32_t bits = std::bit_cast<uint32_t>(f);
12183

12284
uint32_t sign = (bits >> 31) & 0x1;
12385
int32_t exp = ((bits >> 23) & 0xFF) - 127;
@@ -180,7 +142,7 @@ struct Float16 {
180142
return static_cast<int16_t>((sign << 15) | (fp16Exp << 10) | fp16Mantissa);
181143
}
182144

183-
static float fp16BitsToFloat(int16_t bits) {
145+
static constexpr float fp16BitsToFloat(int16_t bits) {
184146
uint16_t ubits = static_cast<uint16_t>(bits);
185147
uint32_t sign = (ubits >> 15) & 0x1;
186148
uint32_t exp = (ubits >> 10) & 0x1F;
@@ -209,109 +171,69 @@ struct Float16 {
209171
result = (sign << 31) | ((exp + 127 - 15) << 23) | (mantissa << 13);
210172
}
211173

212-
float f;
213-
std::memcpy(&f, &result, sizeof(f));
214-
return f;
174+
return std::bit_cast<float>(result);
215175
}
216176
};
217177

218178
// Brain floating point (BFloat16)
219179
// Format: 1 sign bit, 8 exponent bits, 7 mantissa bits
220180
// Same exponent range as float32, just truncated mantissa
181+
//
182+
// This type provides implicit conversions to/from float, allowing seamless
183+
// interoperability with float arithmetic. All operations are performed in
184+
// float precision through these conversions.
221185
struct BFloat16 {
222186
int16_t data;
223187

224-
BFloat16() : data(0) {}
225-
explicit BFloat16(int16_t raw) : data(raw) {}
188+
constexpr BFloat16() : data(0) {}
226189

227-
// Construct from float
228-
explicit BFloat16(float f) { data = floatToBf16Bits(f); }
190+
// Construct from float (handles double via implicit conversion)
191+
constexpr BFloat16(float f) : data(floatToBf16Bits(f)) {}
229192

230193
// Convert to float
231-
float toFloat() const { return bf16BitsToFloat(data); }
232-
233-
// Implicit conversion to float for arithmetic
234-
explicit operator float() const { return toFloat(); }
235-
236-
// Arithmetic operators (perform math in float32)
237-
BFloat16 operator+(const BFloat16 &other) const {
238-
return BFloat16(toFloat() + other.toFloat());
239-
}
194+
constexpr float toFloat() const { return bf16BitsToFloat(data); }
240195

241-
BFloat16 operator-(const BFloat16 &other) const {
242-
return BFloat16(toFloat() - other.toFloat());
243-
}
196+
// Implicit conversion to float for seamless interoperability
197+
// Arithmetic and comparisons work through this conversion
198+
constexpr operator float() const { return toFloat(); }
244199

245-
BFloat16 operator*(const BFloat16 &other) const {
246-
return BFloat16(toFloat() * other.toFloat());
247-
}
248-
249-
BFloat16 operator/(const BFloat16 &other) const {
250-
return BFloat16(toFloat() / other.toFloat());
251-
}
252-
253-
BFloat16 operator-() const { return BFloat16(-toFloat()); }
200+
// Unary negation
201+
constexpr BFloat16 operator-() const { return BFloat16(-toFloat()); }
254202

255203
// Compound assignment operators
256-
BFloat16 &operator+=(const BFloat16 &other) {
257-
*this = *this + other;
204+
constexpr BFloat16 &operator+=(BFloat16 other) {
205+
*this = BFloat16(toFloat() + other.toFloat());
258206
return *this;
259207
}
260208

261-
BFloat16 &operator-=(const BFloat16 &other) {
262-
*this = *this - other;
209+
constexpr BFloat16 &operator-=(BFloat16 other) {
210+
*this = BFloat16(toFloat() - other.toFloat());
263211
return *this;
264212
}
265213

266-
BFloat16 &operator*=(const BFloat16 &other) {
267-
*this = *this * other;
214+
constexpr BFloat16 &operator*=(BFloat16 other) {
215+
*this = BFloat16(toFloat() * other.toFloat());
268216
return *this;
269217
}
270218

271-
BFloat16 &operator/=(const BFloat16 &other) {
272-
*this = *this / other;
219+
constexpr BFloat16 &operator/=(BFloat16 other) {
220+
*this = BFloat16(toFloat() / other.toFloat());
273221
return *this;
274222
}
275223

276-
// Comparison operators
277-
bool operator==(const BFloat16 &other) const {
278-
return toFloat() == other.toFloat();
279-
}
280-
281-
bool operator!=(const BFloat16 &other) const {
282-
return toFloat() != other.toFloat();
283-
}
284-
285-
bool operator<(const BFloat16 &other) const {
286-
return toFloat() < other.toFloat();
287-
}
288-
289-
bool operator<=(const BFloat16 &other) const {
290-
return toFloat() <= other.toFloat();
291-
}
292-
293-
bool operator>(const BFloat16 &other) const {
294-
return toFloat() > other.toFloat();
295-
}
296-
297-
bool operator>=(const BFloat16 &other) const {
298-
return toFloat() >= other.toFloat();
299-
}
300-
301224
// Create from raw bits
302-
static BFloat16 fromBits(int16_t bits) {
225+
static constexpr BFloat16 fromBits(int16_t bits) {
303226
BFloat16 result;
304227
result.data = bits;
305228
return result;
306229
}
307230

308231
// Get raw bits
309-
int16_t toBits() const { return data; }
232+
constexpr int16_t toBits() const { return data; }
310233

311234
private:
312-
static int16_t floatToBf16Bits(float f) {
313-
uint32_t bits;
314-
std::memcpy(&bits, &f, sizeof(bits));
235+
static constexpr int16_t floatToBf16Bits(float f) {
236+
uint32_t bits = std::bit_cast<uint32_t>(f);
315237

316238
// Round to nearest even
317239
uint32_t rounding = 0x7FFF + ((bits >> 16) & 1);
@@ -321,12 +243,10 @@ struct BFloat16 {
321243
return static_cast<int16_t>(bits >> 16);
322244
}
323245

324-
static float bf16BitsToFloat(int16_t bits) {
246+
static constexpr float bf16BitsToFloat(int16_t bits) {
325247
// bf16 is just the upper 16 bits of float32
326248
uint32_t result = static_cast<uint32_t>(static_cast<uint16_t>(bits)) << 16;
327-
float f;
328-
std::memcpy(&f, &result, sizeof(f));
329-
return f;
249+
return std::bit_cast<float>(result);
330250
}
331251
};
332252

0 commit comments

Comments
 (0)