1616
1717// ================================================================================
1818// this file has been auto-generated, do not modify its contents!
19- // date: 2023-10-12 17:25:02.978518
20- // git hash: 25f9bb64a14ef5a93b356d6089becd7139a0141f
19+ // date: 2023-10-12 19:42:20.177310
20+ // git hash: 4824f9787b219562d394b19c74f701ff75d8fb56
2121// ================================================================================
2222
2323#ifndef KERNEL_FLOAT_MACROS_H
@@ -892,19 +892,27 @@ KERNEL_FLOAT_INLINE map_type<F, V> map(F fun, const V& input) {
892892 return result;
893893}
894894
895+ namespace detail {
896+ // Indicates that elements of type `T` offer less precision than floats, thus operations
897+ // on elements of type `T` can be performed by upcasting them to ` float`.
898+ template <typename T>
899+ struct allow_float_fallback {
900+ static constexpr bool value = false ;
901+ };
902+
903+ template <>
904+ struct allow_float_fallback <float > {
905+ static constexpr bool value = true ;
906+ };
907+ } // namespace detail
908+
895909enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO };
896910
897911namespace ops {
912+
898913template <typename T, typename R, RoundingMode m = RoundingMode::ANY, typename = void >
899914struct cast ;
900915
901- template <typename T, typename R>
902- struct cast <T, R, RoundingMode::ANY> {
903- KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
904- return R (input);
905- }
906- };
907-
908916template <typename T, RoundingMode m>
909917struct cast <T, T, m> {
910918 KERNEL_FLOAT_INLINE T operator ()(T input) noexcept {
@@ -918,6 +926,41 @@ struct cast<T, T, RoundingMode::ANY> {
918926 return input;
919927 }
920928};
929+
930+ template <typename T, typename R, typename = void >
931+ struct cast_float_fallback ;
932+
933+ template <typename T, typename R>
934+ struct cast <T, R, RoundingMode::ANY> {
935+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
936+ return cast_float_fallback<T, R> {}(input);
937+ }
938+ };
939+
940+ template <typename T, typename R, typename >
941+ struct cast_float_fallback {
942+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
943+ return R (input);
944+ }
945+ };
946+
947+ // clang-format off
948+ template <typename T, typename R>
949+ struct cast_float_fallback <
950+ T,
951+ R,
952+ enable_if_t <
953+ !is_same_type<T, float > &&
954+ !is_same_type<R, float > &&
955+ (detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value)
956+ >
957+ > {
958+ KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
959+ return cast<float , R> {}(cast<T, float > {}(input));
960+ }
961+ };
962+ // clang-format on
963+
921964} // namespace ops
922965
923966/* *
@@ -973,20 +1016,6 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input)
9731016KERNEL_FLOAT_DEFINE_UNARY_OP (bit_not, ~, ~input)
9741017KERNEL_FLOAT_DEFINE_UNARY_OP (logical_not, !, (ops::cast<bool , T> {}(!ops::cast<T, bool > {}(input))))
9751018
976- namespace detail {
977- // Indicates that elements of type `T` offer less precision than floats, thus operations
978- // on elements of type `T` can be performed by upcasting them to ` float`.
979- template <typename T>
980- struct allow_float_fallback {
981- static constexpr bool value = false ;
982- };
983-
984- template <>
985- struct allow_float_fallback <float > {
986- static constexpr bool value = true ;
987- };
988- } // namespace detail
989-
9901019#define KERNEL_FLOAT_DEFINE_UNARY_MATH (NAME ) \
9911020 namespace ops { \
9921021 template <typename T, typename = void > \
@@ -1460,7 +1489,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
14601489 template <typename T> \
14611490 struct NAME { \
14621491 KERNEL_FLOAT_INLINE T operator ()(T left, T right) { \
1463- return T (EXPR); \
1492+ return ops::cast< decltype (EXPR), T> {}(EXPR); \
14641493 } \
14651494 }; \
14661495 } \
@@ -3497,15 +3526,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
34973526 }; \
34983527 }
34993528#else
3500- #define KERNEL_FLOAT_FP16_BINARY_FUN (NAME, FUN1, FUN2 ) \
3501- namespace ops { \
3502- template <> \
3503- struct NAME <__half> { \
3504- KERNEL_FLOAT_INLINE __half operator ()(__half left, __half right) const { \
3505- return __half (ops::NAME<float > {}(float (left), float (right))); \
3506- } \
3507- }; \
3508- }
3529+ #define KERNEL_FLOAT_FP16_BINARY_FUN (NAME, FUN1, FUN2 )
35093530#endif
35103531
35113532KERNEL_FLOAT_FP16_BINARY_FUN (add, __hadd, __hadd2)
@@ -3793,16 +3814,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
37933814 }; \
37943815 }
37953816#else
3796- #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 ) \
3797- namespace ops { \
3798- template <> \
3799- struct NAME <__nv_bfloat16> { \
3800- KERNEL_FLOAT_INLINE __nv_bfloat16 \
3801- operator ()(__nv_bfloat16 left, __nv_bfloat16 right) const { \
3802- return __nv_bfloat16 (ops::NAME<float > {}(float (left), float (right))); \
3803- } \
3804- }; \
3805- }
3817+ #define KERNEL_FLOAT_BF16_BINARY_FUN (NAME, FUN1, FUN2 )
38063818#endif
38073819
38083820KERNEL_FLOAT_BF16_BINARY_FUN (add, __hadd, __hadd2)
@@ -3822,20 +3834,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2)
38223834KERNEL_FLOAT_BF16_BINARY_FUN (greater_equal, __hge, __hgt2)
38233835
38243836namespace ops {
3825- template <typename T>
3826- struct cast <T, __nv_bfloat16> {
3827- KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(T input) {
3828- return __float2bfloat16 (ops::cast<T, float > {}(input));
3829- };
3830- };
3831-
3832- template <typename T>
3833- struct cast <__nv_bfloat16, T> {
3834- KERNEL_FLOAT_INLINE T operator ()(__nv_bfloat16 input) {
3835- return ops::cast<float , T> {}(__bfloat162float (input));
3836- };
3837- };
3838-
38393837template <>
38403838struct cast <double , __nv_bfloat16> {
38413839 KERNEL_FLOAT_INLINE __nv_bfloat16 operator ()(double input) {
@@ -3957,10 +3955,6 @@ struct dot_impl<__nv_bfloat16, N> {
39573955
39583956
39593957namespace kernel_float {
3960- #if KERNEL_FLOAT_CUDA_ARCH >= 800
3961- KERNEL_FLOAT_BF16_CAST (__half, __float2bfloat16(input), __bfloat162float(input));
3962- #endif
3963-
39643958template <>
39653959struct promote_type <__nv_bfloat16, __half> {
39663960 using type = float ;
@@ -4007,6 +4001,39 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
40074001 static constexpr bool value = true ;
40084002};
40094003} // namespace detail
4004+
4005+ #define KERNEL_FLOAT_FP8_CAST (T ) \
4006+ namespace ops { \
4007+ template <> \
4008+ struct cast <T, __nv_fp8_e4m3> { \
4009+ KERNEL_FLOAT_INLINE __nv_fp8_e4m3 operator ()(T v) const { \
4010+ return __nv_fp8_e4m3 (v); \
4011+ } \
4012+ }; \
4013+ \
4014+ template <> \
4015+ struct cast <__nv_fp8_e4m3, T> { \
4016+ KERNEL_FLOAT_INLINE T operator ()(__nv_fp8_e4m3 v) const { \
4017+ return T (v); \
4018+ } \
4019+ }; \
4020+ \
4021+ template <> \
4022+ struct cast <T, __nv_fp8_e5m2> { \
4023+ KERNEL_FLOAT_INLINE __nv_fp8_e5m2 operator ()(T v) const { \
4024+ return __nv_fp8_e5m2 (v); \
4025+ } \
4026+ }; \
4027+ \
4028+ template <> \
4029+ struct cast <__nv_fp8_e5m2, T> { \
4030+ KERNEL_FLOAT_INLINE T operator ()(__nv_fp8_e5m2 v) const { \
4031+ return T (v); \
4032+ } \
4033+ }; \
4034+ }
4035+
4036+ KERNEL_FLOAT_FP8_CAST (double )
40104037} // namespace kernel_float
40114038
40124039#if KERNEL_FLOAT_FP16_AVAILABLE
@@ -4015,6 +4042,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
40154042namespace kernel_float {
40164043KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e4m3)
40174044KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e5m2)
4045+ KERNEL_FLOAT_FP8_CAST (__half)
40184046} // namespace kernel_float
40194047#endif // KERNEL_FLOAT_FP16_AVAILABLE
40204048
@@ -4024,6 +4052,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
40244052namespace kernel_float {
40254053KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e4m3)
40264054KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e5m2)
4055+ KERNEL_FLOAT_FP8_CAST (__nv_bfloat16)
40274056} // namespace kernel_float
40284057#endif // KERNEL_FLOAT_BF16_AVAILABLE
40294058
0 commit comments