@@ -1002,7 +1002,39 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
10021002 }
10031003#endif
10041004
1005- #if defined(__ARM_NEON ) && defined(__aarch64__ )
1005+ #if defined(__ARM_FEATURE_SVE ) && defined(__aarch64__ )
1006+
1007+ inline static svfloat32_t ggml_v_expf (svbool_t pg , svfloat32_t x ) {
1008+ const svfloat32_t r = svdup_n_f32_x (pg , 0x1.8p23f );
1009+ const svfloat32_t z = svmla_n_f32_x (pg , r , x , 0x1.715476p+0f );
1010+ const svfloat32_t n = svsub_f32_x (pg , z , r );
1011+ const svfloat32_t b = svmls_n_f32_x (pg , svmls_n_f32_x (pg , x , n , 0x1.62e4p-1f ), n , 0x1.7f7d1cp-20f );
1012+ const svuint32_t e = svlsl_n_u32_x (pg , svreinterpret_u32_f32 (z ), 23 );
1013+ const svfloat32_t k = svreinterpret_f32_u32 (svadd_u32_x (pg , e , svreinterpret_u32_f32 (svdup_n_f32_x (pg , 1 ))));
1014+ const svbool_t c = svacgt_n_f32 (pg , n , 126 );
1015+ const svfloat32_t u = svmul_f32_x (pg , b , b );
1016+ const svfloat32_t j = svmla_f32_x (pg ,
1017+ svmul_n_f32_x (pg , b , 0x1.ffffecp-1f ),
1018+ svmla_f32_x (pg , svmla_f32_x (pg , svdup_n_f32_x (pg , 0x1.fffdb6p-2f ), svdup_n_f32_x (pg , 0x1.555e66p-3f ), b ),
1019+ svmla_f32_x (pg , svdup_n_f32_x (pg , 0x1.573e2ep-5f ), svdup_n_f32_x (pg , 0x1.0e4020p-7f ), b ), u ), u );
1020+ const svuint32_t d = svdup_n_u32_z (svcmple_n_f32 (pg , n , 0.0 ), 0x82000000 );
1021+ const svfloat32_t s1 = svreinterpret_f32_u32 (svadd_n_u32_x (pg , d , 0x7f000000 ));
1022+ const svfloat32_t s2 = svreinterpret_f32_u32 (svsub_u32_x (pg , e , d ));
1023+ return svsel_f32 (svacgt_f32 (pg , n , svdup_n_f32_x (pg , 192 )), svmul_f32_x (pg , s1 , s1 ),
1024+ svsel_f32 (c , svmul_f32_x (pg , svmla_f32_x (pg , s2 , s2 , j ), s1 ), svmla_f32_x (pg , k , k , j )));
1025+ }
1026+
1027+ // computes silu x/(1+exp(-x)) in single precision vector
1028+ inline static svfloat32_t ggml_v_silu (svbool_t pg , svfloat32_t x ) {
1029+ const svfloat32_t one = svdup_n_f32_x (pg , 1.0f );
1030+ const svfloat32_t zero = svdup_n_f32_x (pg , 0.0f );
1031+ const svfloat32_t neg_x = svsub_f32_x (pg , zero , x );
1032+ const svfloat32_t exp_neg_x = ggml_v_expf (pg , neg_x );
1033+ const svfloat32_t one_plus_exp_neg_x = svadd_f32_x (pg , one , exp_neg_x );
1034+ return svdiv_f32_x (pg , x , one_plus_exp_neg_x );
1035+ }
1036+
1037+ #elif defined(__ARM_NEON ) && defined(__aarch64__ )
10061038
10071039// adapted from arm limited optimized routine
10081040// the maximum error is 1.45358 plus 0.5 ulps
0 commit comments