|
1 |
| -#pragma once |
2 |
| - |
3 |
| -#include <c10/macros/Macros.h> |
4 |
| -#include <c10/util/bit_cast.h> |
5 |
| - |
6 |
| -#include <limits> |
7 |
| - |
8 |
| -C10_CLANG_DIAGNOSTIC_PUSH() |
9 |
| -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") |
10 |
| -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") |
11 |
| -#endif |
12 |
| - |
13 |
| -#if defined(CL_SYCL_LANGUAGE_VERSION) |
14 |
| -#include <CL/sycl.hpp> // for SYCL 1.2.1 |
15 |
| -#elif defined(SYCL_LANGUAGE_VERSION) |
16 |
| -#include <sycl/sycl.hpp> // for SYCL 2020 |
17 |
| -#endif |
18 |
| - |
19 |
| -namespace c10 { |
20 |
| - |
21 |
| -/// Constructors |
22 |
| -inline C10_HOST_DEVICE BFloat16::BFloat16(float value) |
23 |
| - : |
24 |
| -#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ |
25 |
| - __CUDA_ARCH__ >= 800 |
26 |
| - x(__bfloat16_as_ushort(__float2bfloat16(value))) |
27 |
| -#elif defined(__SYCL_DEVICE_ONLY__) && \ |
28 |
| - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
29 |
| - x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value))) |
30 |
| -#else |
31 |
| - // RNE by default |
32 |
| - x(detail::round_to_nearest_even(value)) |
33 |
| -#endif |
34 |
| -{ |
35 |
| -} |
36 |
| - |
37 |
| -/// Implicit conversions |
38 |
| -inline C10_HOST_DEVICE BFloat16::operator float() const { |
39 |
| -#if defined(__CUDACC__) && !defined(USE_ROCM) |
40 |
| - return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x)); |
41 |
| -#elif defined(__SYCL_DEVICE_ONLY__) && \ |
42 |
| - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
43 |
| - return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x)); |
44 |
| -#else |
45 |
| - return detail::f32_from_bits(x); |
46 |
| -#endif |
47 |
| -} |
48 |
| - |
49 |
| -#if defined(__CUDACC__) && !defined(USE_ROCM) |
50 |
| -inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { |
51 |
| - x = *reinterpret_cast<const unsigned short*>(&value); |
52 |
| -} |
53 |
| -inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { |
54 |
| - return *reinterpret_cast<const __nv_bfloat16*>(&x); |
55 |
| -} |
56 |
| -#endif |
57 |
| - |
58 |
| -#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
59 |
| -inline C10_HOST_DEVICE BFloat16::BFloat16( |
60 |
| - const sycl::ext::oneapi::bfloat16& value) { |
61 |
| - x = *reinterpret_cast<const unsigned short*>(&value); |
62 |
| -} |
63 |
| -inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { |
64 |
| - return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x); |
65 |
| -} |
66 |
| -#endif |
67 |
| - |
68 |
| -// CUDA intrinsics |
69 |
| - |
70 |
| -#if defined(__CUDACC__) || defined(__HIPCC__) |
71 |
| -inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { |
72 |
| -#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
73 |
| - return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr)); |
74 |
| -#else |
75 |
| - return *ptr; |
76 |
| -#endif |
77 |
| -} |
78 |
| -#endif |
79 |
| - |
80 |
| -/// Arithmetic |
81 |
| - |
82 |
| -inline C10_HOST_DEVICE BFloat16 |
83 |
| -operator+(const BFloat16& a, const BFloat16& b) { |
84 |
| - return static_cast<float>(a) + static_cast<float>(b); |
85 |
| -} |
86 |
| - |
87 |
| -inline C10_HOST_DEVICE BFloat16 |
88 |
| -operator-(const BFloat16& a, const BFloat16& b) { |
89 |
| - return static_cast<float>(a) - static_cast<float>(b); |
90 |
| -} |
91 |
| - |
92 |
| -inline C10_HOST_DEVICE BFloat16 |
93 |
| -operator*(const BFloat16& a, const BFloat16& b) { |
94 |
| - return static_cast<float>(a) * static_cast<float>(b); |
95 |
| -} |
96 |
| - |
97 |
| -inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) |
98 |
| - __ubsan_ignore_float_divide_by_zero__ { |
99 |
| - return static_cast<float>(a) / static_cast<float>(b); |
100 |
| -} |
101 |
| - |
102 |
| -inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { |
103 |
| - return -static_cast<float>(a); |
104 |
| -} |
105 |
| - |
106 |
| -inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { |
107 |
| - a = a + b; |
108 |
| - return a; |
109 |
| -} |
110 |
| - |
111 |
| -inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { |
112 |
| - a = a - b; |
113 |
| - return a; |
114 |
| -} |
115 |
| - |
116 |
| -inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { |
117 |
| - a = a * b; |
118 |
| - return a; |
119 |
| -} |
120 |
| - |
121 |
| -inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { |
122 |
| - a = a / b; |
123 |
| - return a; |
124 |
| -} |
125 |
| - |
126 |
| -inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { |
127 |
| - a.x = a.x | b.x; |
128 |
| - return a; |
129 |
| -} |
130 |
| - |
131 |
| -inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { |
132 |
| - a.x = a.x ^ b.x; |
133 |
| - return a; |
134 |
| -} |
135 |
| - |
136 |
| -inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { |
137 |
| - a.x = a.x & b.x; |
138 |
| - return a; |
139 |
| -} |
140 |
| - |
141 |
| -/// Arithmetic with floats |
142 |
| - |
143 |
| -inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { |
144 |
| - return static_cast<float>(a) + b; |
145 |
| -} |
146 |
| -inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { |
147 |
| - return static_cast<float>(a) - b; |
148 |
| -} |
149 |
| -inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { |
150 |
| - return static_cast<float>(a) * b; |
151 |
| -} |
152 |
| -inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { |
153 |
| - return static_cast<float>(a) / b; |
154 |
| -} |
155 |
| - |
156 |
| -inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { |
157 |
| - return a + static_cast<float>(b); |
158 |
| -} |
159 |
| -inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { |
160 |
| - return a - static_cast<float>(b); |
161 |
| -} |
162 |
| -inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { |
163 |
| - return a * static_cast<float>(b); |
164 |
| -} |
165 |
| -inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { |
166 |
| - return a / static_cast<float>(b); |
167 |
| -} |
168 |
| - |
169 |
| -inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { |
170 |
| - return a += static_cast<float>(b); |
171 |
| -} |
172 |
| -inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { |
173 |
| - return a -= static_cast<float>(b); |
174 |
| -} |
175 |
| -inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { |
176 |
| - return a *= static_cast<float>(b); |
177 |
| -} |
178 |
| -inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { |
179 |
| - return a /= static_cast<float>(b); |
180 |
| -} |
181 |
| - |
182 |
| -/// Arithmetic with doubles |
183 |
| - |
184 |
| -inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { |
185 |
| - return static_cast<double>(a) + b; |
186 |
| -} |
187 |
| -inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { |
188 |
| - return static_cast<double>(a) - b; |
189 |
| -} |
190 |
| -inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { |
191 |
| - return static_cast<double>(a) * b; |
192 |
| -} |
193 |
| -inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { |
194 |
| - return static_cast<double>(a) / b; |
195 |
| -} |
196 |
| - |
197 |
| -inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { |
198 |
| - return a + static_cast<double>(b); |
199 |
| -} |
200 |
| -inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { |
201 |
| - return a - static_cast<double>(b); |
202 |
| -} |
203 |
| -inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { |
204 |
| - return a * static_cast<double>(b); |
205 |
| -} |
206 |
| -inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { |
207 |
| - return a / static_cast<double>(b); |
208 |
| -} |
209 |
| - |
210 |
| -/// Arithmetic with ints |
211 |
| - |
212 |
| -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { |
213 |
| - return a + static_cast<BFloat16>(b); |
214 |
| -} |
215 |
| -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { |
216 |
| - return a - static_cast<BFloat16>(b); |
217 |
| -} |
218 |
| -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { |
219 |
| - return a * static_cast<BFloat16>(b); |
220 |
| -} |
221 |
| -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { |
222 |
| - return a / static_cast<BFloat16>(b); |
223 |
| -} |
224 |
| - |
225 |
| -inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { |
226 |
| - return static_cast<BFloat16>(a) + b; |
227 |
| -} |
228 |
| -inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { |
229 |
| - return static_cast<BFloat16>(a) - b; |
230 |
| -} |
231 |
| -inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { |
232 |
| - return static_cast<BFloat16>(a) * b; |
233 |
| -} |
234 |
| -inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { |
235 |
| - return static_cast<BFloat16>(a) / b; |
236 |
| -} |
237 |
| - |
238 |
| -//// Arithmetic with int64_t |
239 |
| - |
240 |
| -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { |
241 |
| - return a + static_cast<BFloat16>(b); |
242 |
| -} |
243 |
| -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { |
244 |
| - return a - static_cast<BFloat16>(b); |
245 |
| -} |
246 |
| -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { |
247 |
| - return a * static_cast<BFloat16>(b); |
248 |
| -} |
249 |
| -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { |
250 |
| - return a / static_cast<BFloat16>(b); |
251 |
| -} |
252 |
| - |
253 |
| -inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { |
254 |
| - return static_cast<BFloat16>(a) + b; |
255 |
| -} |
256 |
| -inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { |
257 |
| - return static_cast<BFloat16>(a) - b; |
258 |
| -} |
259 |
| -inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { |
260 |
| - return static_cast<BFloat16>(a) * b; |
261 |
| -} |
262 |
| -inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { |
263 |
| - return static_cast<BFloat16>(a) / b; |
264 |
| -} |
265 |
| - |
266 |
| -// Overloading < and > operators, because std::max and std::min use them. |
267 |
| - |
268 |
| -inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { |
269 |
| - return float(lhs) > float(rhs); |
270 |
| -} |
271 |
| - |
272 |
| -inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { |
273 |
| - return float(lhs) < float(rhs); |
274 |
| -} |
275 |
| - |
276 |
| -} // namespace c10 |
277 |
| - |
278 |
| -namespace std { |
279 |
| - |
280 |
| -template <> |
281 |
| -class numeric_limits<c10::BFloat16> { |
282 |
| - public: |
283 |
| - static constexpr bool is_signed = true; |
284 |
| - static constexpr bool is_specialized = true; |
285 |
| - static constexpr bool is_integer = false; |
286 |
| - static constexpr bool is_exact = false; |
287 |
| - static constexpr bool has_infinity = true; |
288 |
| - static constexpr bool has_quiet_NaN = true; |
289 |
| - static constexpr bool has_signaling_NaN = true; |
290 |
| - static constexpr auto has_denorm = numeric_limits<float>::has_denorm; |
291 |
| - static constexpr auto has_denorm_loss = |
292 |
| - numeric_limits<float>::has_denorm_loss; |
293 |
| - static constexpr auto round_style = numeric_limits<float>::round_style; |
294 |
| - static constexpr bool is_iec559 = false; |
295 |
| - static constexpr bool is_bounded = true; |
296 |
| - static constexpr bool is_modulo = false; |
297 |
| - static constexpr int digits = 8; |
298 |
| - static constexpr int digits10 = 2; |
299 |
| - static constexpr int max_digits10 = 4; |
300 |
| - static constexpr int radix = 2; |
301 |
| - static constexpr int min_exponent = -125; |
302 |
| - static constexpr int min_exponent10 = -37; |
303 |
| - static constexpr int max_exponent = 128; |
304 |
| - static constexpr int max_exponent10 = 38; |
305 |
| - static constexpr auto traps = numeric_limits<float>::traps; |
306 |
| - static constexpr auto tinyness_before = |
307 |
| - numeric_limits<float>::tinyness_before; |
308 |
| - |
309 |
| - static constexpr c10::BFloat16 min() { |
310 |
| - return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); |
311 |
| - } |
312 |
| - static constexpr c10::BFloat16 lowest() { |
313 |
| - return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); |
314 |
| - } |
315 |
| - static constexpr c10::BFloat16 max() { |
316 |
| - return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); |
317 |
| - } |
318 |
| - static constexpr c10::BFloat16 epsilon() { |
319 |
| - return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); |
320 |
| - } |
321 |
| - static constexpr c10::BFloat16 round_error() { |
322 |
| - return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); |
323 |
| - } |
324 |
| - static constexpr c10::BFloat16 infinity() { |
325 |
| - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); |
326 |
| - } |
327 |
| - static constexpr c10::BFloat16 quiet_NaN() { |
328 |
| - return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); |
329 |
| - } |
330 |
| - static constexpr c10::BFloat16 signaling_NaN() { |
331 |
| - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); |
332 |
| - } |
333 |
| - static constexpr c10::BFloat16 denorm_min() { |
334 |
| - return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); |
335 |
| - } |
336 |
| -}; |
337 |
| - |
338 |
| -} // namespace std |
339 |
| - |
340 |
| -C10_CLANG_DIAGNOSTIC_POP() |
| 1 | +#include <torch/headeronly/util/BFloat16.h> |
0 commit comments