@@ -963,7 +963,7 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
963963#define GGML_F16_EPR GGML_F32_EPR
964964
965965static inline float32x4_t __lzs_f16cx4_load (const ggml_fp16_t * x ) {
966- #ifdef __NNPA__
966+ #if defined( __NNPA__ )
967967 uint16x8_t v_x = vec_xl (0 , (const ggml_fp16_t * )x );
968968 uint16x8_t nnpa_dlf16 = vec_convert_from_fp16 (v_x , 0 );
969969 return vec_extend_to_fp32_hi (nnpa_dlf16 , 0 );
@@ -980,8 +980,17 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
980980#endif
981981}
982982
983- // TODO: check why this function is not being hit at all
984983static inline void __lzs_f16cx4_store (ggml_fp16_t * x , float32x4_t v_y ) {
984+ #if defined(__NNPA__ )
985+ float32x4_t v_zero = vec_splats (0.0f );
986+ uint16x8_t v_xd = vec_round_from_fp32 (v_y , v_zero , 0 );
987+ uint16x8_t v_x = vec_convert_to_fp16 (v_xd , 0 );
988+
989+ x [0 ] = vec_extract (v_x , 0 );
990+ x [1 ] = vec_extract (v_x , 1 );
991+ x [2 ] = vec_extract (v_x , 2 );
992+ x [3 ] = vec_extract (v_x , 3 );
993+ #else
985994 float arr [4 ];
986995
987996 // note: keep type-cast here to prevent compiler bugs
@@ -991,6 +1000,7 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
9911000 for (int i = 0 ; i < 4 ; i ++ ) {
9921001 x [i ] = GGML_FP32_TO_FP16 (arr [i ]);
9931002 }
1003+ #endif
9941004}
9951005
9961006#define GGML_F16_VEC GGML_F32x4
0 commit comments