@@ -1002,7 +1002,39 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
1002
1002
}
1003
1003
#endif
1004
1004
1005
- #if defined(__ARM_NEON ) && defined(__aarch64__ )
1005
+ #if defined(__ARM_FEATURE_SVE ) && defined(__aarch64__ )
1006
+
1007
+ inline static svfloat32_t ggml_v_expf (svbool_t pg , svfloat32_t x ) {
1008
+ const svfloat32_t r = svdup_n_f32_x (pg , 0x1.8p23f );
1009
+ const svfloat32_t z = svmla_n_f32_x (pg , r , x , 0x1.715476p+0f );
1010
+ const svfloat32_t n = svsub_f32_x (pg , z , r );
1011
+ const svfloat32_t b = svmls_n_f32_x (pg , svmls_n_f32_x (pg , x , n , 0x1.62e4p-1f ), n , 0x1.7f7d1cp-20f );
1012
+ const svuint32_t e = svlsl_n_u32_x (pg , svreinterpret_u32_f32 (z ), 23 );
1013
+ const svfloat32_t k = svreinterpret_f32_u32 (svadd_u32_x (pg , e , svreinterpret_u32_f32 (svdup_n_f32_x (pg , 1 ))));
1014
+ const svbool_t c = svacgt_n_f32 (pg , n , 126 );
1015
+ const svfloat32_t u = svmul_f32_x (pg , b , b );
1016
+ const svfloat32_t j = svmla_f32_x (pg ,
1017
+ svmul_n_f32_x (pg , b , 0x1.ffffecp-1f ),
1018
+ svmla_f32_x (pg , svmla_f32_x (pg , svdup_n_f32_x (pg , 0x1.fffdb6p-2f ), svdup_n_f32_x (pg , 0x1.555e66p-3f ), b ),
1019
+ svmla_f32_x (pg , svdup_n_f32_x (pg , 0x1.573e2ep-5f ), svdup_n_f32_x (pg , 0x1.0e4020p-7f ), b ), u ), u );
1020
+ const svuint32_t d = svdup_n_u32_z (svcmple_n_f32 (pg , n , 0.0 ), 0x82000000 );
1021
+ const svfloat32_t s1 = svreinterpret_f32_u32 (svadd_n_u32_x (pg , d , 0x7f000000 ));
1022
+ const svfloat32_t s2 = svreinterpret_f32_u32 (svsub_u32_x (pg , e , d ));
1023
+ return svsel_f32 (svacgt_f32 (pg , n , svdup_n_f32_x (pg , 192 )), svmul_f32_x (pg , s1 , s1 ),
1024
+ svsel_f32 (c , svmul_f32_x (pg , svmla_f32_x (pg , s2 , s2 , j ), s1 ), svmla_f32_x (pg , k , k , j )));
1025
+ }
1026
+
1027
+ // computes silu x/(1+exp(-x)) in single precision vector
1028
+ inline static svfloat32_t ggml_v_silu (svbool_t pg , svfloat32_t x ) {
1029
+ const svfloat32_t one = svdup_n_f32_x (pg , 1.0f );
1030
+ const svfloat32_t zero = svdup_n_f32_x (pg , 0.0f );
1031
+ const svfloat32_t neg_x = svsub_f32_x (pg , zero , x );
1032
+ const svfloat32_t exp_neg_x = ggml_v_expf (pg , neg_x );
1033
+ const svfloat32_t one_plus_exp_neg_x = svadd_f32_x (pg , one , exp_neg_x );
1034
+ return svdiv_f32_x (pg , x , one_plus_exp_neg_x );
1035
+ }
1036
+
1037
+ #elif defined(__ARM_NEON ) && defined(__aarch64__ )
1006
1038
1007
1039
// adapted from Arm Limited's optimized routines
1008
1040
// the maximum error is 1.45358 plus 0.5 ulps
0 commit comments