88
99#pragma once
1010
11- #include < sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
12- #include < sycl/detail/memcpy.hpp> // sycl::detail::memcpy
11+ #include < sycl/bit_cast.hpp> // for sycl::bit_cast
12+ #include < sycl/builtins.hpp> // for ceil, cos, exp, exp10, exp2
13+ #include < sycl/detail/memcpy.hpp> // sycl::detail::memcpy
1314#include < sycl/detail/vector_convert.hpp>
14- #include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16, bfloat16ToBits
15+ #include < sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
1516#include < sycl/marray.hpp> // for marray
1617
1718#include < cstring> // for size_t
@@ -46,7 +47,7 @@ constexpr int num_elements_v = sycl::detail::num_elements<T>::value;
4647// significand has non-zero bits.
4748template <typename T>
4849std::enable_if_t <std::is_same_v<T, bfloat16>, bool > isnan (T x) {
49- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
50+ uint16_t XBits = bit_cast< uint16_t > (x);
5051 return (((XBits & 0x7F80 ) == 0x7F80 ) && (XBits & 0x7F )) ? true : false ;
5152}
5253
@@ -90,15 +91,15 @@ template <typename T>
9091std::enable_if_t <std::is_same_v<T, bfloat16>, T> fabs (T x) {
9192#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
9293 (__SYCL_CUDA_ARCH__ >= 800 )
93- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
94- return oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
94+ uint16_t XBits = bit_cast< uint16_t > (x);
95+ return bit_cast<bfloat16> (__clc_fabs (XBits));
9596#else
9697 if (!isnan (x)) {
97- const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000 ;
98- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
98+ constexpr uint16_t SignMask = 0x8000 ;
99+ uint16_t XBits = bit_cast< uint16_t > (x);
99100 x = ((XBits & SignMask) == SignMask)
100- ? oneapi::detail::bitsToBfloat16 (XBits & ~SignMask)
101- : x ;
101+ ? bit_cast<bfloat16, uint16_t > (XBits & ~SignMask)
102+ : bit_cast<bfloat16>(x) ;
102103 }
103104 return x;
104105#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -116,9 +117,8 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
116117 }
117118
118119 if (N % 2 ) {
119- oneapi::detail::Bfloat16StorageT XBits =
120- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
121- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
120+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
121+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fabs (XBits));
122122 }
123123#else
124124 for (size_t i = 0 ; i < N; i++) {
@@ -154,25 +154,22 @@ template <typename T>
154154std::enable_if_t <std::is_same_v<T, bfloat16>, T> fmin (T x, T y) {
155155#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
156156 (__SYCL_CUDA_ARCH__ >= 800 )
157- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
158- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
159- return oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
157+ uint16_t XBits = bit_cast< uint16_t > (x);
158+ uint16_t YBits = bit_cast< uint16_t > (y);
159+ return bit_cast<bfloat16> (__clc_fmin (XBits, YBits));
160160#else
161- static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
161+ constexpr uint16_t CanonicalNan = 0x7FC0 ;
162162 if (isnan (x) && isnan (y))
163- return oneapi::detail::bitsToBfloat16 (CanonicalNan);
163+ return bit_cast<bfloat16> (CanonicalNan);
164164
165165 if (isnan (x))
166166 return y;
167167 if (isnan (y))
168168 return x;
169- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
170- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
171- if (((XBits | YBits) ==
172- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
173- !(XBits & YBits))
174- return oneapi::detail::bitsToBfloat16 (
175- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 ));
169+ uint16_t XBits = bit_cast<uint16_t >(x);
170+ uint16_t YBits = bit_cast<uint16_t >(y);
171+ if (((XBits | YBits) == static_cast <uint16_t >(0x8000 )) && !(XBits & YBits))
172+ return bit_cast<bfloat16>(static_cast <uint16_t >(0x8000 ));
176173
177174 return (x < y) ? x : y;
178175#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -192,11 +189,9 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
192189 }
193190
194191 if (N % 2 ) {
195- oneapi::detail::Bfloat16StorageT XBits =
196- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
197- oneapi::detail::Bfloat16StorageT YBits =
198- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
199- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
192+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
193+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
194+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fmin (XBits, YBits));
200195 }
201196#else
202197 for (size_t i = 0 ; i < N; i++) {
@@ -237,24 +232,22 @@ template <typename T>
237232std::enable_if_t <std::is_same_v<T, bfloat16>, T> fmax (T x, T y) {
238233#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
239234 (__SYCL_CUDA_ARCH__ >= 800 )
240- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
241- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
242- return oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
235+ uint16_t XBits = bit_cast< uint16_t > (x);
236+ uint16_t YBits = bit_cast< uint16_t > (y);
237+ return bit_cast<bfloat16> (__clc_fmax (XBits, YBits));
243238#else
244- static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
239+ constexpr uint16_t CanonicalNan = 0x7FC0 ;
245240 if (isnan (x) && isnan (y))
246- return oneapi::detail::bitsToBfloat16 (CanonicalNan);
241+ return bit_cast<bfloat16> (CanonicalNan);
247242
248243 if (isnan (x))
249244 return y;
250245 if (isnan (y))
251246 return x;
252- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
253- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
254- if (((XBits | YBits) ==
255- static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
256- !(XBits & YBits))
257- return oneapi::detail::bitsToBfloat16 (0 );
247+ uint16_t XBits = bit_cast<uint16_t >(x);
248+ uint16_t YBits = bit_cast<uint16_t >(y);
249+ if (((XBits | YBits) == static_cast <uint16_t >(0x8000 )) && !(XBits & YBits))
250+ return bit_cast<bfloat16, uint16_t >(0 );
258251
259252 return (x > y) ? x : y;
260253#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -274,11 +267,9 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
274267 }
275268
276269 if (N % 2 ) {
277- oneapi::detail::Bfloat16StorageT XBits =
278- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
279- oneapi::detail::Bfloat16StorageT YBits =
280- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
281- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
270+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
271+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
272+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fmax (XBits, YBits));
282273 }
283274#else
284275 for (size_t i = 0 ; i < N; i++) {
@@ -319,10 +310,10 @@ template <typename T>
319310std::enable_if_t <std::is_same_v<T, bfloat16>, T> fma (T x, T y, T z) {
320311#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
321312 (__SYCL_CUDA_ARCH__ >= 800 )
322- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
323- oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
324- oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits (z);
325- return oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
313+ uint16_t XBits = bit_cast< uint16_t > (x);
314+ uint16_t YBits = bit_cast< uint16_t > (y);
315+ uint16_t ZBits = bit_cast< uint16_t > (z);
316+ return bit_cast<bfloat16> (__clc_fma (XBits, YBits, ZBits));
326317#else
327318 return sycl::ext::oneapi::bfloat16{sycl::fma (float {x}, float {y}, float {z})};
328319#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) &&
@@ -344,13 +335,10 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
344335 }
345336
346337 if (N % 2 ) {
347- oneapi::detail::Bfloat16StorageT XBits =
348- oneapi::detail::bfloat16ToBits (x[N - 1 ]);
349- oneapi::detail::Bfloat16StorageT YBits =
350- oneapi::detail::bfloat16ToBits (y[N - 1 ]);
351- oneapi::detail::Bfloat16StorageT ZBits =
352- oneapi::detail::bfloat16ToBits (z[N - 1 ]);
353- res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
338+ uint16_t XBits = bit_cast<uint16_t >(x[N - 1 ]);
339+ uint16_t YBits = bit_cast<uint16_t >(y[N - 1 ]);
340+ uint16_t ZBits = bit_cast<uint16_t >(z[N - 1 ]);
341+ res[N - 1 ] = bit_cast<bfloat16>(__clc_fma (XBits, YBits, ZBits));
354342 }
355343#else
356344 for (size_t i = 0 ; i < N; i++) {
0 commit comments