@@ -176,16 +176,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
176176 }; \
177177 }
178178#else
179- #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 ) \
180- namespace ops { \
181- template <> \
182- struct NAME <__nv_bfloat16> { \
183- KERNEL_FLOAT_INLINE __nv_bfloat16 \
184- operator ()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
185- return __nv_bfloat16 (ops::NAME<float > {}(float (left), float (right))); \
186- } \
187- }; \
188- }
179+ #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 )
189180#endif
190181
191182KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
@@ -205,20 +196,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
205196KERNEL_FLOAT_BF16_BINARY_FUN (greater_equal, __hge, __hgt2)
206197
207198namespace ops {
208- template <typename T>
209- struct cast <T, __nv_bfloat16> {
210- KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(T input) {
211- return __float2bfloat16 (ops::cast<T, float > {}(input));
212- };
213- };
214-
215- template <typename T>
216- struct cast <__nv_bfloat16, T> {
217- KERNEL_FLOAT_INLINE T operator ()(__nv_bfloat16 input) {
218- return ops::cast<float , T> {}(__bfloat162float (input));
219- };
220- };
221-
222199template <>
223200struct cast <double , __nv_bfloat16> {
224201 KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(double input) {
@@ -340,10 +317,6 @@ struct dot_impl<__nv_bfloat16, N> {
340317#include " fp16.h"
341318
342319namespace kernel_float {
343- #if KERNEL_FLOAT_CUDA_ARCH >= 800
344- KERNEL_FLOAT_BF16_CAST (__half, __float2bfloat16(input), __bfloat162float(input));
345- #endif
346-
347320template <>
348321struct promote_type <__nv_bfloat16, __half> {
349322 using type = float ;
0 commit comments