|
1 | | -#pragma once |
2 | | - |
3 | | -#include <c10/macros/Macros.h> |
4 | | -#include <c10/util/bit_cast.h> |
5 | | - |
6 | | -#include <limits> |
7 | | - |
8 | | -C10_CLANG_DIAGNOSTIC_PUSH() |
9 | | -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") |
10 | | -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") |
11 | | -#endif |
12 | | - |
13 | | -#if defined(CL_SYCL_LANGUAGE_VERSION) |
14 | | -#include <CL/sycl.hpp> // for SYCL 1.2.1 |
15 | | -#elif defined(SYCL_LANGUAGE_VERSION) |
16 | | -#include <sycl/sycl.hpp> // for SYCL 2020 |
17 | | -#endif |
18 | | - |
19 | | -namespace c10 { |
20 | | - |
21 | | -/// Constructors |
22 | | -inline C10_HOST_DEVICE BFloat16::BFloat16(float value) |
23 | | - : |
24 | | -#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ |
25 | | - __CUDA_ARCH__ >= 800 |
26 | | - x(__bfloat16_as_ushort(__float2bfloat16(value))) |
27 | | -#elif defined(__SYCL_DEVICE_ONLY__) && \ |
28 | | - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
29 | | - x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value))) |
30 | | -#else |
31 | | - // RNE by default |
32 | | - x(detail::round_to_nearest_even(value)) |
33 | | -#endif |
34 | | -{ |
35 | | -} |
36 | | - |
37 | | -/// Implicit conversions |
38 | | -inline C10_HOST_DEVICE BFloat16::operator float() const { |
39 | | -#if defined(__CUDACC__) && !defined(USE_ROCM) |
40 | | - return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x)); |
41 | | -#elif defined(__SYCL_DEVICE_ONLY__) && \ |
42 | | - defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
43 | | - return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x)); |
44 | | -#else |
45 | | - return detail::f32_from_bits(x); |
46 | | -#endif |
47 | | -} |
48 | | - |
49 | | -#if defined(__CUDACC__) && !defined(USE_ROCM) |
50 | | -inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { |
51 | | - x = *reinterpret_cast<const unsigned short*>(&value); |
52 | | -} |
53 | | -inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { |
54 | | - return *reinterpret_cast<const __nv_bfloat16*>(&x); |
55 | | -} |
56 | | -#endif |
57 | | - |
58 | | -#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) |
59 | | -inline C10_HOST_DEVICE BFloat16::BFloat16( |
60 | | - const sycl::ext::oneapi::bfloat16& value) { |
61 | | - x = *reinterpret_cast<const unsigned short*>(&value); |
62 | | -} |
63 | | -inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { |
64 | | - return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x); |
65 | | -} |
66 | | -#endif |
67 | | - |
68 | | -// CUDA intrinsics |
69 | | - |
70 | | -#if defined(__CUDACC__) || defined(__HIPCC__) |
71 | | -inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { |
72 | | -#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 |
73 | | - return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr)); |
74 | | -#else |
75 | | - return *ptr; |
76 | | -#endif |
77 | | -} |
78 | | -#endif |
79 | | - |
80 | | -/// Arithmetic |
81 | | - |
82 | | -inline C10_HOST_DEVICE BFloat16 |
83 | | -operator+(const BFloat16& a, const BFloat16& b) { |
84 | | - return static_cast<float>(a) + static_cast<float>(b); |
85 | | -} |
86 | | - |
87 | | -inline C10_HOST_DEVICE BFloat16 |
88 | | -operator-(const BFloat16& a, const BFloat16& b) { |
89 | | - return static_cast<float>(a) - static_cast<float>(b); |
90 | | -} |
91 | | - |
92 | | -inline C10_HOST_DEVICE BFloat16 |
93 | | -operator*(const BFloat16& a, const BFloat16& b) { |
94 | | - return static_cast<float>(a) * static_cast<float>(b); |
95 | | -} |
96 | | - |
97 | | -inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) |
98 | | - __ubsan_ignore_float_divide_by_zero__ { |
99 | | - return static_cast<float>(a) / static_cast<float>(b); |
100 | | -} |
101 | | - |
102 | | -inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { |
103 | | - return -static_cast<float>(a); |
104 | | -} |
105 | | - |
106 | | -inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { |
107 | | - a = a + b; |
108 | | - return a; |
109 | | -} |
110 | | - |
111 | | -inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { |
112 | | - a = a - b; |
113 | | - return a; |
114 | | -} |
115 | | - |
116 | | -inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { |
117 | | - a = a * b; |
118 | | - return a; |
119 | | -} |
120 | | - |
121 | | -inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { |
122 | | - a = a / b; |
123 | | - return a; |
124 | | -} |
125 | | - |
126 | | -inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { |
127 | | - a.x = a.x | b.x; |
128 | | - return a; |
129 | | -} |
130 | | - |
131 | | -inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { |
132 | | - a.x = a.x ^ b.x; |
133 | | - return a; |
134 | | -} |
135 | | - |
136 | | -inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { |
137 | | - a.x = a.x & b.x; |
138 | | - return a; |
139 | | -} |
140 | | - |
141 | | -/// Arithmetic with floats |
142 | | - |
143 | | -inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { |
144 | | - return static_cast<float>(a) + b; |
145 | | -} |
146 | | -inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { |
147 | | - return static_cast<float>(a) - b; |
148 | | -} |
149 | | -inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { |
150 | | - return static_cast<float>(a) * b; |
151 | | -} |
152 | | -inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { |
153 | | - return static_cast<float>(a) / b; |
154 | | -} |
155 | | - |
156 | | -inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { |
157 | | - return a + static_cast<float>(b); |
158 | | -} |
159 | | -inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { |
160 | | - return a - static_cast<float>(b); |
161 | | -} |
162 | | -inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { |
163 | | - return a * static_cast<float>(b); |
164 | | -} |
165 | | -inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { |
166 | | - return a / static_cast<float>(b); |
167 | | -} |
168 | | - |
169 | | -inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { |
170 | | - return a += static_cast<float>(b); |
171 | | -} |
172 | | -inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { |
173 | | - return a -= static_cast<float>(b); |
174 | | -} |
175 | | -inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { |
176 | | - return a *= static_cast<float>(b); |
177 | | -} |
178 | | -inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { |
179 | | - return a /= static_cast<float>(b); |
180 | | -} |
181 | | - |
182 | | -/// Arithmetic with doubles |
183 | | - |
184 | | -inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { |
185 | | - return static_cast<double>(a) + b; |
186 | | -} |
187 | | -inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { |
188 | | - return static_cast<double>(a) - b; |
189 | | -} |
190 | | -inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { |
191 | | - return static_cast<double>(a) * b; |
192 | | -} |
193 | | -inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { |
194 | | - return static_cast<double>(a) / b; |
195 | | -} |
196 | | - |
197 | | -inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { |
198 | | - return a + static_cast<double>(b); |
199 | | -} |
200 | | -inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { |
201 | | - return a - static_cast<double>(b); |
202 | | -} |
203 | | -inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { |
204 | | - return a * static_cast<double>(b); |
205 | | -} |
206 | | -inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { |
207 | | - return a / static_cast<double>(b); |
208 | | -} |
209 | | - |
210 | | -/// Arithmetic with ints |
211 | | - |
212 | | -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { |
213 | | - return a + static_cast<BFloat16>(b); |
214 | | -} |
215 | | -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { |
216 | | - return a - static_cast<BFloat16>(b); |
217 | | -} |
218 | | -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { |
219 | | - return a * static_cast<BFloat16>(b); |
220 | | -} |
221 | | -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { |
222 | | - return a / static_cast<BFloat16>(b); |
223 | | -} |
224 | | - |
225 | | -inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { |
226 | | - return static_cast<BFloat16>(a) + b; |
227 | | -} |
228 | | -inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { |
229 | | - return static_cast<BFloat16>(a) - b; |
230 | | -} |
231 | | -inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { |
232 | | - return static_cast<BFloat16>(a) * b; |
233 | | -} |
234 | | -inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { |
235 | | - return static_cast<BFloat16>(a) / b; |
236 | | -} |
237 | | - |
238 | | -//// Arithmetic with int64_t |
239 | | - |
240 | | -inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { |
241 | | - return a + static_cast<BFloat16>(b); |
242 | | -} |
243 | | -inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { |
244 | | - return a - static_cast<BFloat16>(b); |
245 | | -} |
246 | | -inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { |
247 | | - return a * static_cast<BFloat16>(b); |
248 | | -} |
249 | | -inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { |
250 | | - return a / static_cast<BFloat16>(b); |
251 | | -} |
252 | | - |
253 | | -inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { |
254 | | - return static_cast<BFloat16>(a) + b; |
255 | | -} |
256 | | -inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { |
257 | | - return static_cast<BFloat16>(a) - b; |
258 | | -} |
259 | | -inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { |
260 | | - return static_cast<BFloat16>(a) * b; |
261 | | -} |
262 | | -inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { |
263 | | - return static_cast<BFloat16>(a) / b; |
264 | | -} |
265 | | - |
266 | | -// Overloading < and > operators, because std::max and std::min use them. |
267 | | - |
268 | | -inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { |
269 | | - return float(lhs) > float(rhs); |
270 | | -} |
271 | | - |
272 | | -inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { |
273 | | - return float(lhs) < float(rhs); |
274 | | -} |
275 | | - |
276 | | -} // namespace c10 |
277 | | - |
278 | | -namespace std { |
279 | | - |
280 | | -template <> |
281 | | -class numeric_limits<c10::BFloat16> { |
282 | | - public: |
283 | | - static constexpr bool is_signed = true; |
284 | | - static constexpr bool is_specialized = true; |
285 | | - static constexpr bool is_integer = false; |
286 | | - static constexpr bool is_exact = false; |
287 | | - static constexpr bool has_infinity = true; |
288 | | - static constexpr bool has_quiet_NaN = true; |
289 | | - static constexpr bool has_signaling_NaN = true; |
290 | | - static constexpr auto has_denorm = numeric_limits<float>::has_denorm; |
291 | | - static constexpr auto has_denorm_loss = |
292 | | - numeric_limits<float>::has_denorm_loss; |
293 | | - static constexpr auto round_style = numeric_limits<float>::round_style; |
294 | | - static constexpr bool is_iec559 = false; |
295 | | - static constexpr bool is_bounded = true; |
296 | | - static constexpr bool is_modulo = false; |
297 | | - static constexpr int digits = 8; |
298 | | - static constexpr int digits10 = 2; |
299 | | - static constexpr int max_digits10 = 4; |
300 | | - static constexpr int radix = 2; |
301 | | - static constexpr int min_exponent = -125; |
302 | | - static constexpr int min_exponent10 = -37; |
303 | | - static constexpr int max_exponent = 128; |
304 | | - static constexpr int max_exponent10 = 38; |
305 | | - static constexpr auto traps = numeric_limits<float>::traps; |
306 | | - static constexpr auto tinyness_before = |
307 | | - numeric_limits<float>::tinyness_before; |
308 | | - |
309 | | - static constexpr c10::BFloat16 min() { |
310 | | - return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); |
311 | | - } |
312 | | - static constexpr c10::BFloat16 lowest() { |
313 | | - return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); |
314 | | - } |
315 | | - static constexpr c10::BFloat16 max() { |
316 | | - return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); |
317 | | - } |
318 | | - static constexpr c10::BFloat16 epsilon() { |
319 | | - return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); |
320 | | - } |
321 | | - static constexpr c10::BFloat16 round_error() { |
322 | | - return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); |
323 | | - } |
324 | | - static constexpr c10::BFloat16 infinity() { |
325 | | - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); |
326 | | - } |
327 | | - static constexpr c10::BFloat16 quiet_NaN() { |
328 | | - return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); |
329 | | - } |
330 | | - static constexpr c10::BFloat16 signaling_NaN() { |
331 | | - return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); |
332 | | - } |
333 | | - static constexpr c10::BFloat16 denorm_min() { |
334 | | - return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); |
335 | | - } |
336 | | -}; |
337 | | - |
338 | | -} // namespace std |
339 | | - |
340 | | -C10_CLANG_DIAGNOSTIC_POP() |
| 1 | +#include <torch/headeronly/util/BFloat16.h> |
0 commit comments