|
4 | 4 | // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. |
5 | 5 |
|
6 | 6 | #include <c10/macros/Macros.h> |
| 7 | +#include <c10/util/bit_cast.h> |
7 | 8 | #include <cmath> |
8 | 9 | #include <cstdint> |
9 | 10 | #include <cstring> |
@@ -67,13 +68,123 @@ inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { |
67 | 68 | #endif |
68 | 69 | return UINT16_C(0x7FC0); |
69 | 70 | } else { |
70 | | - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
71 | | - union { |
72 | | - uint32_t U32; // NOLINT(facebook-hte-BadMemberName) |
73 | | - float F32; // NOLINT(facebook-hte-BadMemberName) |
74 | | - }; |
| 71 | + const uint32_t U32 = c10::bit_cast<uint32_t>(src); |
| 72 | + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); |
| 73 | + return static_cast<uint16_t>((U32 + rounding_bias) >> 16); |
| 74 | + } |
| 75 | +} |
| 76 | +} // namespace detail |
| 77 | + |
| 78 | +struct alignas(2) BFloat16 { |
| 79 | + uint16_t x; |
| 80 | + |
| 81 | + // HIP wants __host__ __device__ tag, CUDA does not |
| 82 | +#if defined(USE_ROCM) && defined(__HIPCC__) |
| 83 | + C10_HOST_DEVICE BFloat16() = default; |
| 84 | +#else |
| 85 | + BFloat16() = default; |
| 86 | +#endif |
| 87 | + |
| 88 | + struct from_bits_t {}; |
| 89 | + static constexpr C10_HOST_DEVICE from_bits_t from_bits() { |
| 90 | + return from_bits_t(); |
| 91 | + } |
| 92 | + |
| 93 | + constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) |
| 94 | + : x(bits) {} |
| 95 | + /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); |
| 96 | + inline C10_HOST_DEVICE operator float() const; |
| 97 | + |
| 98 | +#if defined(__CUDACC__) && !defined(USE_ROCM) |
| 99 | + inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); |
| 100 | + explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; |
| 101 | +#endif |
| 102 | + |
| 103 | +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
| 104 | + inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); |
| 105 | + explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; |
| 106 | +#endif |
| 107 | +}; |
| 108 | + |
| 109 | +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { |
| 110 | + out << (float)value; |
| 111 | + return out; |
| 112 | +} |
| 113 | + |
| 114 | +} // namespace c10 |
| 115 | + |
| 116 | +#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep |
| 117 | +#pragma once |
| 118 | + |
| 119 | +// Defines the bloat16 type (brain floating-point). This representation uses |
| 120 | +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. |
| 121 | + |
| 122 | +#include <c10/macros/Macros.h> |
| 123 | +#include <c10/util/bit_cast.h> |
| 124 | +#include <cmath> |
| 125 | +#include <cstdint> |
| 126 | +#include <cstring> |
| 127 | +#include <iosfwd> |
| 128 | +#include <ostream> |
| 129 | + |
| 130 | +#if defined(__CUDACC__) && !defined(USE_ROCM) |
| 131 | +#include <cuda_bf16.h> |
| 132 | +#endif |
| 133 | + |
| 134 | +#if defined(CL_SYCL_LANGUAGE_VERSION) |
| 135 | +#include <CL/sycl.hpp> // for SYCL 1.2.1 |
| 136 | +#elif defined(SYCL_LANGUAGE_VERSION) |
| 137 | +#include <sycl/sycl.hpp> // for SYCL 2020 |
| 138 | +#endif |
| 139 | + |
| 140 | +namespace c10 { |
| 141 | + |
| 142 | +namespace detail { |
| 143 | +inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { |
| 144 | + float res = 0; |
| 145 | + uint32_t tmp = src; |
| 146 | + tmp <<= 16; |
| 147 | + |
| 148 | +#if defined(USE_ROCM) && defined(__HIPCC__) |
| 149 | + float* tempRes; |
| 150 | + |
| 151 | + // We should be using memcpy in order to respect the strict aliasing rule |
| 152 | + // but it fails in the HIP environment. |
| 153 | + tempRes = reinterpret_cast<float*>(&tmp); |
| 154 | + res = *tempRes; |
| 155 | +#else |
| 156 | + std::memcpy(&res, &tmp, sizeof(tmp)); |
| 157 | +#endif |
75 | 158 |
|
76 | | - F32 = src; |
| 159 | + return res; |
| 160 | +} |
| 161 | + |
| 162 | +inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { |
| 163 | + uint32_t res = 0; |
| 164 | + |
| 165 | +#if defined(USE_ROCM) && defined(__HIPCC__) |
| 166 | + // We should be using memcpy in order to respect the strict aliasing rule |
| 167 | + // but it fails in the HIP environment. |
| 168 | + uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); |
| 169 | + res = *tempRes; |
| 170 | +#else |
| 171 | + std::memcpy(&res, &src, sizeof(res)); |
| 172 | +#endif |
| 173 | + |
| 174 | + return res >> 16; |
| 175 | +} |
| 176 | + |
| 177 | +inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { |
| 178 | +#if defined(USE_ROCM) && defined(__HIPCC__) |
| 179 | + if (src != src) { |
| 180 | +#elif defined(_MSC_VER) |
| 181 | + if (isnan(src)) { |
| 182 | +#else |
| 183 | + if (std::isnan(src)) { |
| 184 | +#endif |
| 185 | + return UINT16_C(0x7FC0); |
| 186 | + } else { |
| 187 | + const uint32_t U32 = c10::bit_cast<uint32_t>(src); |
77 | 188 | uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); |
78 | 189 | return static_cast<uint16_t>((U32 + rounding_bias) >> 16); |
79 | 190 | } |
@@ -111,9 +222,7 @@ struct alignas(2) BFloat16 { |
111 | 222 | #endif |
112 | 223 | }; |
113 | 224 |
|
114 | | -C10_API inline std::ostream& operator<<( |
115 | | - std::ostream& out, |
116 | | - const BFloat16& value) { |
| 225 | +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { |
117 | 226 | out << (float)value; |
118 | 227 | return out; |
119 | 228 | } |
|
0 commit comments