@@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
501501}
502502
// Widens the 8-bit lanes of the 128-bit vector `a` into 16-bit lanes of a
// 256-bit result; the removed code did this by pairing each byte with a zero
// byte, i.e. an unsigned (zero) extension.
// The '-' lines are the old implementation: interleave a zero vector with the
// low and high halves of `a` and join the two 128-bit results via lasx_set_q.
// The '+' line is the new implementation: place `a` in a 256-bit register with
// ____m256i() and call the single __lasx_vext2xv_hu_bu intrinsic.
503503static __m256i lasx_extu8_16 (__m128i a ) {
504- __m128i zero = __lsx_vldi (0 );
505- __m128i vlo = __lsx_vilvl_b (zero , a );
506- __m128i vhi = __lsx_vilvh_b (zero , a );
507- return lasx_set_q (vhi , vlo );
504+ return __lasx_vext2xv_hu_bu (____m256i (a ));
508505}
509506
// Signed counterpart of lasx_extu8_16: widens the 8-bit lanes of `a` into
// 16-bit lanes of a 256-bit result.
// The '-' lines are the old implementation: build a per-byte mask from the
// comparison `a < 0` (__lsx_vslti_b) and interleave that mask with the low and
// high halves of `a`, so each byte is paired with its sign fill.
// The '+' line is the new implementation: one __lasx_vext2xv_h_b call on
// ____m256i(a).
510507static __m256i lasx_ext8_16 (__m128i a ) {
511- __m128i sign = __lsx_vslti_b (a , 0 );
512- __m128i vlo = __lsx_vilvl_b (sign , a );
513- __m128i vhi = __lsx_vilvh_b (sign , a );
514- return lasx_set_q (vhi , vlo );
508+ return __lasx_vext2xv_h_b (____m256i (a ));
515509}
516510
// Widens the eight 16-bit lanes of `a` into eight 32-bit lanes of a 256-bit
// result.
// The '-' lines are the old implementation: extract each halfword to a general
// register with __lsx_vpickve2gr_h and insert it into word lane 0..7 of `tmp1`
// one at a time. `tmp1` was declared without an initializer and passed to the
// first insert; every lane is overwritten before the return.
// The '+' line is the new implementation: one __lasx_vext2xv_w_h call on
// ____m256i(a), with no per-lane round trip through scalar registers.
517511static __m256i lasx_ext16_32 (__m128i a ) {
518- __m256i tmp1 ;
519- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 0 ), 0 );
520- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 1 ), 1 );
521- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 2 ), 2 );
522- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 3 ), 3 );
523- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 4 ), 4 );
524- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 5 ), 5 );
525- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 6 ), 6 );
526- tmp1 = __lasx_xvinsgr2vr_w (tmp1 , __lsx_vpickve2gr_h (a , 7 ), 7 );
527- return tmp1 ;
512+ return __lasx_vext2xv_w_h (____m256i (a ));
528513}
529514
530515static __m128i lasx_extracti128 ( __m256i a , int pos ) {
@@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
592577// horizontally add 8 floats
// Returns the sum of the eight float lanes of `x`.
// Reduction order: add the two 128-bit halves of `x`, fold the 64-bit halves
// of that result together, then add lane 1 into lane 0 and return lane 0.
// The change in this hunk only affects how lane 0 is read out. The '-' lines
// moved it through an integer register (__lsx_vpickve2gr_w) and reinterpreted
// the bits via ft_union; the '+' line indexes lane 0 of the vector directly
// after casting it to v4f32.
593578static inline float hsum_float_8 (const __m256 x ) {
594579 __m128 res = lasx_extractf128 (x , 1 );
595- ft_union tmp ;
596580 res = __lsx_vfadd_s (res , lasx_extractf128 (x , 0 ));
597581 res = __lsx_vfadd_s (res , (__m128 )__lsx_vpickod_d ((__m128i )res , (__m128i )res ));
// Lane 1 of `res` is inserted into lane 0 of an all-zero vector, so this add
// leaves lane0 + lane1 in lane 0.
598582 res = __lsx_vfadd_s (res , (__m128 )__lsx_vinsgr2vr_w (__lsx_vldi (0 ), __lsx_vpickve2gr_w (res , 1 ), 0 ));
599- tmp .i = __lsx_vpickve2gr_w (res , 0 );
600- return tmp .f ;
583+ return ((v4f32 )res )[0 ];
601584}
602585
603586// horizontally add 8 int32_t
@@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
939922
940923#elif defined(__loongarch_asx )
941924 for (int i = 0 ; i < nb ; i ++ ) {
942- ft_union fi ;
943925 __m256 v0 = (__m256 )__lasx_xvld ( x , 0 );
944926 __m256 v1 = (__m256 )__lasx_xvld ( x , 32 );
945927 __m256 v2 = (__m256 )__lasx_xvld ( x , 64 );
@@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
957939 max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vpickod_d ((__m128i ) max4 , (__m128i )max4 ) );
958940 __m128 tmp = max4 ;
959941 max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vinsgr2vr_w (tmp , __lsx_vpickve2gr_w ( max4 , 1 ), 0 ));
960- fi .i = __lsx_vpickve2gr_w ( (__m128i )max4 , 0 );
961- const float max_scalar = fi .f ;
942+ const float max_scalar = ((v4f32 )max4 )[0 ];
962943
963944 // Quantize these floats
964945 const float d = max_scalar / 127.f ;
@@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
12631244
12641245#elif defined(__loongarch_asx )
12651246 for (int i = 0 ; i < nb ; i ++ ) {
1266- ft_union ft ;
12671247 __m256 v0 = (__m256 )__lasx_xvld ( x , 0 );
12681248 __m256 v1 = (__m256 )__lasx_xvld ( x , 32 );
12691249 __m256 v2 = (__m256 )__lasx_xvld ( x , 64 );
@@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
12811261 max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vpickod_d ((__m128i ) max4 , (__m128i )max4 ) );
12821262 __m128 tmp = max4 ;
12831263 max4 = __lsx_vfmax_s ( max4 , (__m128 )__lsx_vextrins_w ((__m128i )tmp , (__m128i )max4 , 0x10 ));
1284- ft .i = __lsx_vpickve2gr_w ( (__m128i )max4 , 0 );
1285- const float max_scalar = ft .f ;
1264+ const float max_scalar = ((v4f32 )max4 )[0 ];
12861265
12871266 // Quantize these floats
12881267 const float d = max_scalar / 127.f ;
@@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
61546133 acc_m = __lsx_vfadd_s (acc_m , (__m128 )tmp1 );
61556134
61566135
6157- ft_union fi ;
6158- fi .i = __lsx_vpickve2gr_w (acc_m , 0 );
6159- * s = hsum_float_8 (acc ) + fi .f ;
6136+ * s = hsum_float_8 (acc ) + ((v4f32 )acc_m )[0 ];
61606137#else
61616138
61626139 const uint8_t * scales = (const uint8_t * )& utmp [0 ];
0 commit comments