44#include " macros.h"
55
66#if KERNEL_FLOAT_BF16_AVAILABLE
7+ // #define CUDA_NO_BFLOAT16 (1)
8+ // #define __CUDA_NO_BFLOAT16_OPERATORS__ (1)
9+ // #define __CUDA_NO_BFLOAT162_OPERATORS__ (1)
10+ // #define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1)
11+
712#if KERNEL_FLOAT_IS_CUDA
813#include < cuda_bf16.h>
914#elif KERNEL_FLOAT_IS_HIP
@@ -76,21 +81,24 @@ struct allow_float_fallback<__bfloat16> {
7681 }; \
7782 }
7883
79- KERNEL_FLOAT_BF16_UNARY_FUN (abs, ::__habs, ::__habs2)
80- KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
81- KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
84+ KERNEL_FLOAT_BF16_UNARY_FUN (sin, ::hsin, ::h2sin)
8285KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos)
86+
8387KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp)
8488KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10)
85- KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
8689KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log)
8790KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2)
88- KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
89- KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
90- KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
91+
9192KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
92- KERNEL_FLOAT_BF16_UNARY_FUN(trunc , ::htrunc , ::h2trunc )
93+ KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt , ::hrsqrt , ::h2rsqrt )
9394KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
95+
96+ KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2)
97+ KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor)
98+ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil)
99+ KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint)
100+ KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
101+ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2)
94102#endif
95103
96104#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
@@ -99,7 +107,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
99107 template <> \
100108 struct NAME <__bfloat16> { \
101109 KERNEL_FLOAT_INLINE __bfloat16 operator ()(__bfloat16 left, __bfloat16 right) const { \
102- return FUN1 (left, right); \
110+ return ops::cast< decltype ( FUN1 (left, right)), __bfloat16> {}( FUN1 (left, right)); \
103111 } \
104112 }; \
105113 } \
@@ -159,29 +167,6 @@ struct apply_impl<ops::fma<__bfloat16>, 2, __bfloat16, __bfloat16, __bfloat16, _
159167} // namespace detail
160168#endif
161169
162- namespace ops {
163- template <>
164- struct cast <double , __bfloat16> {
165- KERNEL_FLOAT_INLINE __bfloat16 operator ()(double input) {
166- return __double2bfloat16 (input);
167- };
168- };
169-
170- template <>
171- struct cast <float , __bfloat16> {
172- KERNEL_FLOAT_INLINE __bfloat16 operator ()(float input) {
173- return __float2bfloat16 (input);
174- };
175- };
176-
177- template <>
178- struct cast <__bfloat16, float > {
179- KERNEL_FLOAT_INLINE float operator ()(__bfloat16 input) {
180- return __bfloat162float (input);
181- };
182- };
183- } // namespace ops
184-
185170#define KERNEL_FLOAT_BF16_CAST (T, TO_HALF, FROM_HALF ) \
186171 namespace ops { \
187172 template <> \
@@ -198,31 +183,33 @@ struct cast<__bfloat16, float> {
198183 }; \
199184 }
200185
186+ KERNEL_FLOAT_BF16_CAST (float , __float2bfloat16(input), __bfloat162float(input))
187+ KERNEL_FLOAT_BF16_CAST (double , __double2bfloat16(input), __bfloat162float(input))
188+
201189#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
202190// clang-format off
203191// there are no official char casts. Instead, cast to int and then to char
204192KERNEL_FLOAT_BF16_CAST (char , __int2bfloat16_rn(input), (char )__bfloat162int_rz(input));
205193KERNEL_FLOAT_BF16_CAST (signed char , __int2bfloat16_rn(input), (signed char )__bfloat162int_rz(input));
206194KERNEL_FLOAT_BF16_CAST (unsigned char , __int2bfloat16_rn(input), (unsigned char )__bfloat162int_rz(input));
207195
208- KERNEL_FLOAT_BF16_CAST (signed short , __bfloat162short_rz (input), __short2bfloat16_rn (input));
209- KERNEL_FLOAT_BF16_CAST (signed int , __bfloat162int_rz (input), __int2bfloat16_rn (input));
196+ KERNEL_FLOAT_BF16_CAST (signed short , __short2bfloat16_rn (input), __bfloat162short_rz (input));
197+ KERNEL_FLOAT_BF16_CAST (signed int , __int2bfloat16_rn (input), __bfloat162int_rz (input));
210198KERNEL_FLOAT_BF16_CAST (signed long , __ll2bfloat16_rn(input), (signed long )(__bfloat162ll_rz(input)));
211199KERNEL_FLOAT_BF16_CAST (signed long long , __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
212200
213- KERNEL_FLOAT_BF16_CAST (unsigned short , __bfloat162ushort_rz (input), __ushort2bfloat16_rn (input));
214- KERNEL_FLOAT_BF16_CAST (unsigned int , __bfloat162uint_rz (input), __uint2bfloat16_rn (input));
201+ KERNEL_FLOAT_BF16_CAST (unsigned short , __ushort2bfloat16_rn (input), __bfloat162ushort_rz (input));
202+ KERNEL_FLOAT_BF16_CAST (unsigned int , __uint2bfloat16_rn (input), __bfloat162uint_rz (input));
215203KERNEL_FLOAT_BF16_CAST (unsigned long , __ull2bfloat16_rn(input), (unsigned long )(__bfloat162ull_rz(input)));
216204KERNEL_FLOAT_BF16_CAST (unsigned long long , __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
217205// clang-format on
218206#endif
219207
220208#if KERNEL_FLOAT_IS_CUDA
221- KERNEL_FLOAT_BF16_CAST (
222- bool ,
223- __nv_bfloat16_raw {input ? (unsigned short )0 : (unsigned short )0x3C00 },
224- (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
225-
209+ // KERNEL_FLOAT_BF16_CAST(
210+ // bool,
211+ // __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00},
212+ // (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
226213#elif KERNEL_FLOAT_IS_HIP
227214KERNEL_FLOAT_BF16_CAST (
228215 bool ,
0 commit comments