1515#ifndef FUSILLI_SUPPORT_FLOAT_TYPES_H
1616#define FUSILLI_SUPPORT_FLOAT_TYPES_H
1717
18+ #include < bit>
1819#include < cmath>
1920#include < cstdint>
20- #include < cstring>
2121#include < limits>
2222
2323namespace fusilli {
2424
2525// IEEE 754 half-precision floating point (Float16)
2626// Format: 1 sign bit, 5 exponent bits, 10 mantissa bits
27+ //
28+ // This type provides implicit conversions to/from float, allowing seamless
29+ // interoperability with float arithmetic. All operations are performed in
30+ // float precision through these conversions.
2731struct Float16 {
2832 int16_t data;
2933
30- Float16 () : data(0 ) {}
31- explicit Float16 (int16_t raw) : data(raw) {}
34+ constexpr Float16 () : data(0 ) {}
3235
33- // Construct from float
34- explicit Float16 (float f) { data = floatToFp16Bits (f); }
36+ // Construct from float (handles double via implicit conversion)
37+ constexpr Float16 (float f) : data( floatToFp16Bits(f)) { }
3538
3639 // Convert to float
37- float toFloat () const { return fp16BitsToFloat (data); }
38-
39- // Implicit conversion to float for arithmetic
40- explicit operator float () const { return toFloat (); }
41-
42- // Arithmetic operators (perform math in float32)
43- Float16 operator +(const Float16 &other) const {
44- return Float16 (toFloat () + other.toFloat ());
45- }
40+ constexpr float toFloat () const { return fp16BitsToFloat (data); }
4641
47- Float16 operator -( const Float16 &other) const {
48- return Float16 ( toFloat () - other. toFloat ());
49- }
42+ // Implicit conversion to float for seamless interoperability
43+ // Arithmetic and comparisons work through this conversion
44+ constexpr operator float () const { return toFloat (); }
5045
51- Float16 operator *(const Float16 &other) const {
52- return Float16 (toFloat () * other.toFloat ());
53- }
54-
55- Float16 operator /(const Float16 &other) const {
56- return Float16 (toFloat () / other.toFloat ());
57- }
58-
59- Float16 operator -() const { return Float16 (-toFloat ()); }
46+ // Unary negation
47+ constexpr Float16 operator -() const { return Float16 (-toFloat ()); }
6048
6149 // Compound assignment operators
62- Float16 &operator +=(const Float16 & other) {
63- *this = * this + other;
50+ constexpr Float16 &operator +=(Float16 other) {
51+ *this = Float16 ( toFloat () + other. toFloat ()) ;
6452 return *this ;
6553 }
6654
67- Float16 &operator -=(const Float16 & other) {
68- *this = * this - other;
55+ constexpr Float16 &operator -=(Float16 other) {
56+ *this = Float16 ( toFloat () - other. toFloat ()) ;
6957 return *this ;
7058 }
7159
72- Float16 &operator *=(const Float16 & other) {
73- *this = * this * other;
60+ constexpr Float16 &operator *=(Float16 other) {
61+ *this = Float16 ( toFloat () * other. toFloat ()) ;
7462 return *this ;
7563 }
7664
77- Float16 &operator /=(const Float16 & other) {
78- *this = * this / other;
65+ constexpr Float16 &operator /=(Float16 other) {
66+ *this = Float16 ( toFloat () / other. toFloat ()) ;
7967 return *this ;
8068 }
8169
82- // Comparison operators
83- bool operator ==(const Float16 &other) const {
84- return toFloat () == other.toFloat ();
85- }
86-
87- bool operator !=(const Float16 &other) const {
88- return toFloat () != other.toFloat ();
89- }
90-
91- bool operator <(const Float16 &other) const {
92- return toFloat () < other.toFloat ();
93- }
94-
95- bool operator <=(const Float16 &other) const {
96- return toFloat () <= other.toFloat ();
97- }
98-
99- bool operator >(const Float16 &other) const {
100- return toFloat () > other.toFloat ();
101- }
102-
103- bool operator >=(const Float16 &other) const {
104- return toFloat () >= other.toFloat ();
105- }
106-
10770 // Create from raw bits
108- static Float16 fromBits (int16_t bits) {
71+ static constexpr Float16 fromBits (int16_t bits) {
10972 Float16 result;
11073 result.data = bits;
11174 return result;
11275 }
11376
11477 // Get raw bits
115- int16_t toBits () const { return data; }
78+ constexpr int16_t toBits () const { return data; }
11679
11780private:
118- static int16_t floatToFp16Bits (float f) {
119- uint32_t bits;
120- std::memcpy (&bits, &f, sizeof (bits));
81+ static constexpr int16_t floatToFp16Bits (float f) {
82+ uint32_t bits = std::bit_cast<uint32_t >(f);
12183
12284 uint32_t sign = (bits >> 31 ) & 0x1 ;
12385 int32_t exp = ((bits >> 23 ) & 0xFF ) - 127 ;
@@ -180,7 +142,7 @@ struct Float16 {
180142 return static_cast <int16_t >((sign << 15 ) | (fp16Exp << 10 ) | fp16Mantissa);
181143 }
182144
183- static float fp16BitsToFloat (int16_t bits) {
145+ static constexpr float fp16BitsToFloat (int16_t bits) {
184146 uint16_t ubits = static_cast <uint16_t >(bits);
185147 uint32_t sign = (ubits >> 15 ) & 0x1 ;
186148 uint32_t exp = (ubits >> 10 ) & 0x1F ;
@@ -209,109 +171,69 @@ struct Float16 {
209171 result = (sign << 31 ) | ((exp + 127 - 15 ) << 23 ) | (mantissa << 13 );
210172 }
211173
212- float f;
213- std::memcpy (&f, &result, sizeof (f));
214- return f;
174+ return std::bit_cast<float >(result);
215175 }
216176};
217177
218178// Brain floating point (BFloat16)
219179// Format: 1 sign bit, 8 exponent bits, 7 mantissa bits
220180// Same exponent range as float32, just truncated mantissa
181+ //
182+ // This type provides implicit conversions to/from float, allowing seamless
183+ // interoperability with float arithmetic. All operations are performed in
184+ // float precision through these conversions.
221185struct BFloat16 {
222186 int16_t data;
223187
224- BFloat16 () : data(0 ) {}
225- explicit BFloat16 (int16_t raw) : data(raw) {}
188+ constexpr BFloat16 () : data(0 ) {}
226189
227- // Construct from float
228- explicit BFloat16 (float f) { data = floatToBf16Bits (f); }
190+ // Construct from float (handles double via implicit conversion)
191+ constexpr BFloat16 (float f) : data( floatToBf16Bits(f)) { }
229192
230193 // Convert to float
231- float toFloat () const { return bf16BitsToFloat (data); }
232-
233- // Implicit conversion to float for arithmetic
234- explicit operator float () const { return toFloat (); }
235-
236- // Arithmetic operators (perform math in float32)
237- BFloat16 operator +(const BFloat16 &other) const {
238- return BFloat16 (toFloat () + other.toFloat ());
239- }
194+ constexpr float toFloat () const { return bf16BitsToFloat (data); }
240195
241- BFloat16 operator -( const BFloat16 &other) const {
242- return BFloat16 ( toFloat () - other. toFloat ());
243- }
196+ // Implicit conversion to float for seamless interoperability
197+ // Arithmetic and comparisons work through this conversion
198+ constexpr operator float () const { return toFloat (); }
244199
245- BFloat16 operator *(const BFloat16 &other) const {
246- return BFloat16 (toFloat () * other.toFloat ());
247- }
248-
249- BFloat16 operator /(const BFloat16 &other) const {
250- return BFloat16 (toFloat () / other.toFloat ());
251- }
252-
253- BFloat16 operator -() const { return BFloat16 (-toFloat ()); }
200+ // Unary negation
201+ constexpr BFloat16 operator -() const { return BFloat16 (-toFloat ()); }
254202
255203 // Compound assignment operators
256- BFloat16 &operator +=(const BFloat16 & other) {
257- *this = * this + other;
204+ constexpr BFloat16 &operator +=(BFloat16 other) {
205+ *this = BFloat16 ( toFloat () + other. toFloat ()) ;
258206 return *this ;
259207 }
260208
261- BFloat16 &operator -=(const BFloat16 & other) {
262- *this = * this - other;
209+ constexpr BFloat16 &operator -=(BFloat16 other) {
210+ *this = BFloat16 ( toFloat () - other. toFloat ()) ;
263211 return *this ;
264212 }
265213
266- BFloat16 &operator *=(const BFloat16 & other) {
267- *this = * this * other;
214+ constexpr BFloat16 &operator *=(BFloat16 other) {
215+ *this = BFloat16 ( toFloat () * other. toFloat ()) ;
268216 return *this ;
269217 }
270218
271- BFloat16 &operator /=(const BFloat16 & other) {
272- *this = * this / other;
219+ constexpr BFloat16 &operator /=(BFloat16 other) {
220+ *this = BFloat16 ( toFloat () / other. toFloat ()) ;
273221 return *this ;
274222 }
275223
276- // Comparison operators
277- bool operator ==(const BFloat16 &other) const {
278- return toFloat () == other.toFloat ();
279- }
280-
281- bool operator !=(const BFloat16 &other) const {
282- return toFloat () != other.toFloat ();
283- }
284-
285- bool operator <(const BFloat16 &other) const {
286- return toFloat () < other.toFloat ();
287- }
288-
289- bool operator <=(const BFloat16 &other) const {
290- return toFloat () <= other.toFloat ();
291- }
292-
293- bool operator >(const BFloat16 &other) const {
294- return toFloat () > other.toFloat ();
295- }
296-
297- bool operator >=(const BFloat16 &other) const {
298- return toFloat () >= other.toFloat ();
299- }
300-
301224 // Create from raw bits
302- static BFloat16 fromBits (int16_t bits) {
225+ static constexpr BFloat16 fromBits (int16_t bits) {
303226 BFloat16 result;
304227 result.data = bits;
305228 return result;
306229 }
307230
308231 // Get raw bits
309- int16_t toBits () const { return data; }
232+ constexpr int16_t toBits () const { return data; }
310233
311234private:
312- static int16_t floatToBf16Bits (float f) {
313- uint32_t bits;
314- std::memcpy (&bits, &f, sizeof (bits));
235+ static constexpr int16_t floatToBf16Bits (float f) {
236+ uint32_t bits = std::bit_cast<uint32_t >(f);
315237
316238 // Round to nearest even
317239 uint32_t rounding = 0x7FFF + ((bits >> 16 ) & 1 );
@@ -321,12 +243,10 @@ struct BFloat16 {
321243 return static_cast <int16_t >(bits >> 16 );
322244 }
323245
324- static float bf16BitsToFloat (int16_t bits) {
246+ static constexpr float bf16BitsToFloat (int16_t bits) {
325247 // bf16 is just the upper 16 bits of float32
326248 uint32_t result = static_cast <uint32_t >(static_cast <uint16_t >(bits)) << 16 ;
327- float f;
328- std::memcpy (&f, &result, sizeof (f));
329- return f;
249+ return std::bit_cast<float >(result);
330250 }
331251};
332252
0 commit comments