@@ -64,7 +64,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
6464#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
6565 namespace detail { \
6666 template <> \
67- struct apply_impl <ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
67+ struct apply_impl <accurate_policy, ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
6868 KERNEL_FLOAT_INLINE static void call (ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
6969 __half2_raw x; \
7070 memcpy (&x, v, 2 * sizeof (T)); \
@@ -73,7 +73,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
7373 } \
7474 }; \
7575 template <> \
76- struct apply_impl <ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
76+ struct apply_impl <accurate_policy, ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
7777 KERNEL_FLOAT_INLINE static void call (ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
7878 __nv_fp8x2_storage_t x; \
7979 memcpy (&x, v, 2 * sizeof (FP8_TY)); \
@@ -91,12 +91,12 @@ KERNEL_FLOAT_FP8_CAST(double)
9191#include "fp16.h"
9292
9393namespace kernel_float {
94- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e4m3)
95- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e5m2)
94+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e4m3)
95+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e5m2)
9696
97- KERNEL_FLOAT_FP8_CAST (__half )
98- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e4m3, __NV_E4M3)
99- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e5m2, __NV_E5M2)
97+ KERNEL_FLOAT_FP8_CAST (half_t )
98+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e4m3, __NV_E4M3)
99+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e5m2, __NV_E5M2)
100100
101101} // namespace kernel_float
102102#endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -105,12 +105,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
105105#include "bf16.h"
106106
107107namespace kernel_float {
108- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e4m3)
109- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e5m2)
108+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e4m3)
109+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e5m2)
110110
111- KERNEL_FLOAT_FP8_CAST (__nv_bfloat16 )
112- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e4m3, __NV_E4M3)
113- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e5m2, __NV_E5M2)
111+ KERNEL_FLOAT_FP8_CAST (bfloat16_t )
112+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e4m3, __NV_E4M3)
113+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e5m2, __NV_E5M2)
114114} // namespace kernel_float
115115#endif // KERNEL_FLOAT_BF16_AVAILABLE
116116
0 commit comments