1616
1717// ================================================================================
1818// this file has been auto-generated, do not modify its contents!
19- // date: 2024-11-18 13:50:24.614671
20- // git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
19+ // date: 2024-11-18 16:57:58.817191
20+ // git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
2121// ================================================================================
2222
2323#ifndef KERNEL_FLOAT_MACROS_H
@@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY;
824824
825825namespace detail {
826826
827+ //
827828template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
828- struct apply_base_impl {
829+ struct apply_fallback_impl {
829830 KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
830- #pragma unroll
831- for (size_t i = 0 ; i < N; i++) {
832- output[i] = fun (args[i]...);
833- }
831+ static_assert (N > 0 , " operation not implemented" );
834832 }
835833};
836834
835+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
836+ struct apply_base_impl : apply_fallback_impl<Policy, F, N, Output, Args...> {};
837+
837838template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
838839struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
839840
841+ // `fast_policy` falls back to `accurate_policy`
840842template <typename F, size_t N, typename Output, typename ... Args>
841- struct apply_base_impl <fast_policy, F, N, Output, Args...>:
843+ struct apply_fallback_impl <fast_policy, F, N, Output, Args...>:
842844 apply_impl<accurate_policy, F, N, Output, Args...> {};
843845
846+ // `approx_policy` falls back to `fast_policy`
844847template <typename F, size_t N, typename Output, typename ... Args>
845- struct apply_base_impl <approx_policy, F, N, Output, Args...>:
848+ struct apply_fallback_impl <approx_policy, F, N, Output, Args...>:
846849 apply_impl<fast_policy, F, N, Output, Args...> {};
847850
851+ // `approx_level_policy` falls back to `approx_policy`
848852template <int Level, typename F, size_t N, typename Output, typename ... Args>
849- struct apply_base_impl <approx_level_policy<Level>, F, N, Output, Args...>:
853+ struct apply_fallback_impl <approx_level_policy<Level>, F, N, Output, Args...>:
850854 apply_impl<approx_policy, F, N, Output, Args...> {};
851855
856+ template <typename F, typename Output, typename ... Args>
857+ struct invoke_impl {
858+ KERNEL_FLOAT_INLINE static Output call (F fun, Args... args) {
859+ return fun (args...);
860+ }
861+ };
862+
863+ // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`.
864+ template <typename F, size_t N, typename Output, typename ... Args>
865+ struct apply_impl <accurate_policy, F, N, Output, Args...> {
866+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
867+ #pragma unroll
868+ for (size_t i = 0 ; i < N; i++) {
869+ output[i] = invoke_impl<F, Output, Args...>::call (fun, args[i]...);
870+ }
871+ }
872+ };
873+
852874template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
853875struct map_impl {
854876 static constexpr size_t packet_size = preferred_vector_size<Output>::value;
@@ -1949,7 +1971,7 @@ struct multiply<bool> {
19491971
19501972namespace detail {
19511973template <typename Policy, typename T, size_t N>
1952- struct apply_impl <Policy, ops::divide<T>, N, T, T, T> {
1974+ struct apply_base_impl <Policy, ops::divide<T>, N, T, T, T> {
19531975 KERNEL_FLOAT_INLINE static void call (ops::divide<T>, T* result, const T* lhs, const T* rhs) {
19541976 T rhs_rcp[N];
19551977
@@ -1959,10 +1981,6 @@ struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
19591981 }
19601982};
19611983
1962- template <typename T, size_t N>
1963- struct apply_impl <accurate_policy, ops::divide<T>, N, T, T, T>:
1964- apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};
1965-
19661984#if KERNEL_FLOAT_IS_DEVICE
19671985template <>
19681986struct apply_impl <fast_policy, ops::divide<float >, 1 , float , float , float > {
@@ -1977,7 +1995,7 @@ struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
19771995namespace detail {
19781996// Override `pow` using `log2` and `exp2`
19791997template <typename Policy, typename T, size_t N>
1980- struct apply_impl <Policy, ops::pow<T>, N, T, T, T> {
1998+ struct apply_base_impl <Policy, ops::pow<T>, N, T, T, T> {
19811999 KERNEL_FLOAT_INLINE static void call (ops::divide<T>, T* result, const T* lhs, const T* rhs) {
19822000 T lhs_log[N];
19832001 T result_log[N];
@@ -1988,10 +2006,6 @@ struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
19882006 apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call ({}, result, result_log);
19892007 }
19902008};
1991-
1992- template <typename T, size_t N>
1993- struct apply_impl <accurate_policy, ops::pow<T>, N, T, T, T>:
1994- apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
19952009} // namespace detail
19962010
19972011template <typename L, typename R, typename T = promoted_vector_value_type<L, R>>
@@ -3218,13 +3232,13 @@ struct fma {
32183232} // namespace ops
32193233
32203234namespace detail {
3221- template <typename Policy, typename T, size_t N>
3222- struct apply_impl <Policy , ops::fma<T>, N, T, T, T, T> {
3235+ template <typename T, size_t N>
3236+ struct apply_impl <accurate_policy , ops::fma<T>, N, T, T, T, T> {
32233237 KERNEL_FLOAT_INLINE
32243238 static void call (ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
32253239 T temp[N];
3226- apply_impl<Policy , ops::multiply<T>, N, T, T, T>::call ({}, temp, a, b);
3227- apply_impl<Policy , ops::add<T>, N, T, T, T>::call ({}, output, temp, c);
3240+ apply_impl<accurate_policy , ops::multiply<T>, N, T, T, T>::call ({}, temp, a, b);
3241+ apply_impl<accurate_policy , ops::add<T>, N, T, T, T>::call ({}, output, temp, c);
32283242 }
32293243};
32303244} // namespace detail
@@ -3992,9 +4006,6 @@ namespace kernel_float {
39924006using half_t = ::__half;
39934007using half2_t = ::__half2;
39944008
3995- using __half = void ;
3996- using __half2 = void ;
3997-
39984009template <>
39994010struct preferred_vector_size <half_t > {
40004011 static constexpr size_t value = 2 ;
@@ -4020,7 +4031,7 @@ template<>
40204031struct allow_float_fallback <half_t > {
40214032 static constexpr bool value = true ;
40224033};
4023- }; // namespace detail
4034+ } // namespace detail
40244035
40254036#if KERNEL_FLOAT_IS_DEVICE
40264037#define KERNEL_FLOAT_FP16_UNARY_FUN (NAME, FUN1, FUN2 ) \
@@ -4469,7 +4480,7 @@ namespace kernel_float {
44694480
44704481namespace approx {
44714482
4472- static_assert (sizeof (unsigned int ) * 8 == 32 , " invalid side of unsigned int" );
4483+ static_assert (sizeof (unsigned int ) * 8 == 32 , " invalid size of unsigned int" );
44734484using uint32_t = unsigned int ;
44744485
44754486template <typename T, typename U>
@@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {
48064817
48074818template <int = 0 >
48084819KERNEL_FLOAT_DEVICE bfloat16x2_t exp (bfloat16x2_t arg) {
4809- static constexpr float SCALE = 1 .44272065994f / 256 .0f ;
4820+ static constexpr float SCALE = 1.44272065994 / 256.0 ;
48104821 static constexpr float OFFSET = 382.4958400542335 ;
4822+ static constexpr float MINIMUM = 382 ;
48114823
4812- auto a = fmaf (bfloat16x2_tfloat (arg.x ), SCALE, OFFSET);
4813- auto b = fmaf (bfloat16x2_tfloat (arg.y ), SCALE, OFFSET);
4824+ float a = fmaxf ( fmaf (bfloat162float (arg.x ), SCALE, OFFSET), MINIMUM );
4825+ float b = fmaxf ( fmaf (bfloat162float (arg.y ), SCALE, OFFSET), MINIMUM );
48144826
48154827 return {
48164828 transmute<__bfloat16>(uint16_t (transmute<uint32_t >(a))),
@@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
48194831#endif
48204832} // namespace approx
48214833
4822- #define KERNEL_FLOAT_DEFINE_APPROX_FUN (FULL_NAME, FUN, DEG ) \
4823- namespace detail { \
4824- template <int Degree> \
4825- struct apply_impl <approx_level_policy<Degree>, ops::FUN<half_t >, 2 , half_t , half_t > { \
4826- KERNEL_FLOAT_INLINE static void \
4827- call (ops::FUN<half_t > fun, half_t * output, const half_t * input) { \
4828- half2_t res = approx::FUN<Degree>(half2_t {input[0 ], input[1 ]}); \
4829- output[0 ] = res.x ; \
4830- output[1 ] = res.y ; \
4831- } \
4832- }; \
4833- template <> \
4834- struct apply_impl <approx_policy, ops::FUN<half_t >, 2 , half_t , half_t >: \
4835- apply_impl<approx_level_policy<DEG>, ops::FUN<half_t >, 2 , half_t , half_t > {}; \
4836- } \
4837- \
4838- template <int Level = -1 , typename V> \
4839- KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
4840- return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
4841- }
4842-
4843- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_sin, sin, 4 )
4844- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_cos, cos, 4 )
4845- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_rsqrt, rsqrt, 1 )
4846- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_sqrt, sqrt, 1 )
4847- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_rcp, rcp, 1 )
4848- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_exp, exp, 0 )
4849- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_log, log, 0 )
4834+ namespace detail {
4835+ template <int Level, typename F, typename T>
4836+ struct apply_impl <approx_level_policy<Level>, F, 1 , T, T> {
4837+ KERNEL_FLOAT_INLINE static void call (F fun, T* output, const T* input) {
4838+ T in2[2 ], out2[2 ];
4839+ out2[0 ] = input[0 ];
4840+ apply_impl<approx_level_policy<Level>, F, 2 , T, T>::call (fun, out2, in2);
4841+ output[0 ] = out2[0 ];
4842+ }
4843+ };
4844+ } // namespace detail
4845+
4846+ #define KERNEL_FLOAT_DEFINE_APPROX_IMPL (T, FUN, DEFAULT_LEVEL ) \
4847+ namespace detail { \
4848+ template <int Degree> \
4849+ struct apply_impl <approx_level_policy<Degree>, ops::FUN<T>, 2 , T, T> { \
4850+ KERNEL_FLOAT_INLINE static void call (ops::FUN<T>, T* output, const T* input) { \
4851+ auto res = approx::FUN<Degree>({input[0 ], input[1 ]}); \
4852+ output[0 ] = res.x ; \
4853+ output[1 ] = res.y ; \
4854+ } \
4855+ }; \
4856+ \
4857+ template <> \
4858+ struct apply_impl <approx_policy, ops::FUN<T>, 2 , T, T>: \
4859+ apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2 , T, T> {}; \
4860+ }
4861+
4862+ #if KERNEL_FLOAT_FP16_AVAILABLE
4863+ KERNEL_FLOAT_DEFINE_APPROX_IMPL (half_t , sin, 4 )
4864+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , cos, 4 )
4865+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rsqrt, 1 )
4866+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , sqrt, 1 )
4867+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rcp, 1 )
4868+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , exp, 0 )
4869+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , log, 0 )
4870+ #endif
4871+
4872+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4873+ KERNEL_FLOAT_DEFINE_APPROX_IMPL (bfloat16_t , cos, 4 )
4874+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , sin, 4 )
4875+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , rcp, 1 )
4876+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , rsqrt, 1 )
4877+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , sqrt, 1 )
4878+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , exp, 0 )
4879+ // KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
4880+ #endif
4881+
4882+ #define KERNEL_FLOAT_DEFINE_APPROX_FUN (FUN ) \
4883+ template <int Level = -1 , typename V> \
4884+ KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
4885+ return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
4886+ }
4887+
4888+ KERNEL_FLOAT_DEFINE_APPROX_FUN (sin)
4889+ KERNEL_FLOAT_DEFINE_APPROX_FUN (cos)
4890+ KERNEL_FLOAT_DEFINE_APPROX_FUN (rsqrt)
4891+ KERNEL_FLOAT_DEFINE_APPROX_FUN (sqrt)
4892+ KERNEL_FLOAT_DEFINE_APPROX_FUN (rcp)
4893+ KERNEL_FLOAT_DEFINE_APPROX_FUN (exp)
4894+ KERNEL_FLOAT_DEFINE_APPROX_FUN (log)
48504895
48514896} // namespace kernel_float
48524897#ifndef KERNEL_FLOAT_FP8_H
0 commit comments