3232
3333#pragma once
3434
35+ #include < cuda/std/cstdint>
36+ #include < cuda/std/bit>
3537#include < cuda/std/cmath>
3638#include < cuda/std/type_traits>
3739
4143
4244namespace matx {
4345
46+ // Constexpr helper functions for float to half conversion
47+ namespace detail {
48+
49+ /* *
50+ * @brief Constexpr conversion from float to FP16 bits
51+ *
52+ * @param f Input float value
53+ * @return uint16_t FP16 bit representation
54+ */
55+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_fp16_bits (float f) {
56+ // Use bit_cast for constexpr context
57+ uint32_t bits = cuda::std::bit_cast<uint32_t >(f);
58+
59+ uint32_t sign = (bits >> 16 ) & 0x8000 ;
60+ int32_t exponent = static_cast <int32_t >(((bits >> 23 ) & 0xff )) - 127 + 15 ;
61+ uint32_t mantissa = (bits >> 13 ) & 0x3ff ;
62+
63+ // Handle special cases
64+ if (exponent <= 0 ) {
65+ // Subnormal or zero
66+ if (exponent < -10 ) {
67+ // Too small, flush to zero
68+ return static_cast <uint16_t >(sign);
69+ }
70+ // Subnormal
71+ mantissa = (mantissa | 0x400 ) >> (1 - exponent);
72+ return static_cast <uint16_t >(sign | mantissa);
73+ } else if (exponent >= 0x1f ) {
74+ // Overflow to infinity or NaN
75+ if (exponent == 0x1f + (127 - 15 ) && mantissa != 0 ) {
76+ // NaN
77+ return static_cast <uint16_t >(sign | 0x7e00 | (mantissa != 0 ? 0x200 : 0 ));
78+ }
79+ // Infinity
80+ return static_cast <uint16_t >(sign | 0x7c00 );
81+ }
82+
83+ return static_cast <uint16_t >(sign | (static_cast <uint32_t >(exponent) << 10 ) | mantissa);
84+ }
85+
86+ /* *
87+ * @brief Constexpr conversion from float to BF16 bits
88+ *
89+ * @param f Input float value
90+ * @return uint16_t BF16 bit representation (top 16 bits of float)
91+ */
92+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_bf16_bits (float f) {
93+ // BF16 is just the top 16 bits of a float32
94+ // With rounding to nearest even
95+ uint32_t bits = cuda::std::bit_cast<uint32_t >(f);
96+
97+ // Round to nearest even
98+ uint32_t rounding_bias = 0x00007FFF + ((bits >> 16 ) & 1 );
99+ bits += rounding_bias;
100+ uint16_t result = static_cast <uint16_t >(bits >> 16 );
101+
102+ return result;
103+ }
104+
105+ /* *
106+ * @brief Helper to convert float to half type at compile time
107+ *
108+ * @tparam T The target half type (__half or __nv_bfloat16)
109+ * @param f Input float value
110+ * @return T Half-precision value
111+ */
112+ template <typename T>
113+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T float_to_half_constexpr (float f) {
114+ if constexpr (cuda::std::is_same_v<T, __half>) {
115+ return cuda::std::bit_cast<__half>(float_to_fp16_bits (f));
116+ } else {
117+ return cuda::std::bit_cast<__nv_bfloat16>(float_to_bf16_bits (f));
118+ }
119+ }
120+
121+ } // namespace detail
122+
44123/* *
45124 * Template class for half precison numbers (__half and __nv_bfloat16). CUDA
46125 * does not have standardized classes/operators available on both host and
@@ -64,12 +143,49 @@ template <typename T> struct alignas(sizeof(T)) matxHalf {
64143 __MATX_INLINE__ matxHalf (const matxHalf<T> &x_) noexcept = default ;
65144
66145 /* *
67- * @brief Copy constructor from arbitrary type
146+ * @brief Constexpr constructor from float
147+ *
148+ * @param f Float value to convert
149+ */
150+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf (float f) noexcept
151+ : x (detail::float_to_half_constexpr<T>(f))
152+ {
153+ }
154+
155+ /* *
156+ * @brief Constexpr constructor from double
157+ *
158+ * @param d Double value to convert
159+ */
160+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf (double d) noexcept
161+ : x (detail::float_to_half_constexpr<T>(static_cast <float >(d)))
162+ {
163+ }
164+
165+ /* *
166+ * @brief Constructor from integral types (constexpr)
167+ *
168+ * @tparam T2 Integral type to copy from
169+ * @param x_ Value to copy
170+ */
171+ template <typename T2,
172+ cuda::std::enable_if_t <cuda::std::is_integral_v<T2>, int > = 0 >
173+ constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf (T2 x_) noexcept
174+ : x (detail::float_to_half_constexpr<T>(static_cast <float >(x_)))
175+ {
176+ }
177+
178+ /* *
179+ * @brief Copy constructor from arbitrary type (non-constexpr for non-arithmetic types)
68180 *
69181 * @tparam T2 Type to copy from
70182 * @param x_ Value to copy
71183 */
72- template <typename T2>
184+ template <typename T2,
185+ cuda::std::enable_if_t <
186+ !cuda::std::is_same_v<cuda::std::decay_t <T2>, float > &&
187+ !cuda::std::is_same_v<cuda::std::decay_t <T2>, double > &&
188+ !cuda::std::is_integral_v<T2>, int > = 0 >
73189 __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf (const T2 &x_) noexcept
74190 : x (static_cast <float >(x_))
75191 {
@@ -1316,3 +1432,38 @@ using matxFp16 = matxHalf<__half>; ///< Alias for fp16
13161432using matxBf16 = matxHalf<__nv_bfloat16>; // /< Alias for bf16
13171433
13181434}; // namespace matx
1435+
1436+ #ifndef __CUDACC_RTC__
1437+ // Add std::formatter specializations for matxFp16 and matxBf16
1438+ #include < format>
1439+
1440+ namespace std {
1441+
1442+ /* *
1443+ * @brief std::formatter specialization for matxFp16
1444+ *
1445+ * Enables matxFp16 to work with std::format by converting to float
1446+ */
1447+ template <>
1448+ struct formatter <matx::matxFp16> : formatter<float > {
1449+ template <typename FormatContext>
1450+ auto format (const matx::matxFp16& val, FormatContext& ctx) const {
1451+ return formatter<float >::format (static_cast <float >(val), ctx);
1452+ }
1453+ };
1454+
1455+ /* *
1456+ * @brief std::formatter specialization for matxBf16
1457+ *
1458+ * Enables matxBf16 to work with std::format by converting to float
1459+ */
1460+ template <>
1461+ struct formatter <matx::matxBf16> : formatter<float > {
1462+ template <typename FormatContext>
1463+ auto format (const matx::matxBf16& val, FormatContext& ctx) const {
1464+ return formatter<float >::format (static_cast <float >(val), ctx);
1465+ }
1466+ };
1467+
1468+ } // namespace std
1469+ #endif
0 commit comments