@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
+#if defined(__loongarch_sx)
+
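+// LSX analogue of _mm_packs_epi32: saturate 32-bit lanes to signed 16-bit and pack a (low half) with b (high half).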
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
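+// LSX analogue of _mm_packs_epi16: saturate 16-bit lanes to signed 8-bit and pack.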
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
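+// Unsigned variant (used like _mm_packus_epi16): saturate 16-bit lanes to the unsigned 8-bit range and pack.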
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
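+// Widening multiply of even/odd byte pairs followed by a saturating add (used like _mm_maddubs_epi16).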
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
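+// Widening multiply of even/odd 16-bit pairs followed by an add (cf. _mm_madd_epi16).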
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
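+// Build a vector from four 32-bit values, 'a' in the highest lane and 'd' in the lowest (cf. _mm_set_epi32).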
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
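+// Byte shuffle following the _mm_shuffle_epi8 convention: the low 4 bits of each index in 'b' select a byte of 'a'; a set sign bit zeroes that lane.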
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f);  // keep the low 4 index bits and the sign bit
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // OR with 0x10 so non-negative indices select from 'a' in vshuf_b
+    mask = __lsx_vsle_b(zero, tmp0); // all-ones where the index byte is non-negative
+    tmp2 = __lsx_vand_v(tmp0, mask); // zero out the indices whose sign bit was set
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
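+// Horizontal pairwise add of 16-bit lanes (cf. _mm_hadd_epi16): sums of adjacent pairs of 'a' in the low half, of 'b' in the high half.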
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
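+// Same pairwise add for 32-bit integer lanes (cf. _mm_hadd_epi32).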
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
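+// Same pairwise add for single-precision float lanes (cf. _mm_hadd_ps).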
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
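+// Horizontally sum all 16 floats of the four input vectors into a single scalar.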
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = lsx_hadd_s(a, b);
+    __m128 res_1 = lsx_hadd_s(c, d);
+    __m128 res = lsx_hadd_s(res_0, res_1);
+    res = lsx_hadd_s(res, res);
+    res = lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
 #if defined(__loongarch_asx)
 
 #ifdef __clang__
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
     return (__m256i)__ret;
 }
 
-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
 static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
     v4i64 __ret = {d, c, b, a};
     return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
     return lasx_set_q(x, y);
 }
 
-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
 static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
     __m256i mask_f, zero, tmp0, tmp2, mask;
     int f = 0x8f;
@@ -482,25 +549,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
     return ret;
 }
 
-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
 static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
     __m256i tmp1, tmp2;
     tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,42 +577,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }
 
-static __m128i lsx_packs_w(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_w(a, 15);
-    tmp1 = __lsx_vsat_w(b, 15);
-    return __lsx_vpickev_h(tmp1, tmp);
-}
-
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-
-static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_h_b(a, b);
-    tmp2 = __lsx_vmulwod_h_b(a, b);
-    return __lsx_vsadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_madd_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_w_h(a, b);
-    tmp2 = __lsx_vmulwod_w_h(a, b);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
@@ -2232,21 +2244,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
     const __m128i off = __lsx_vreplgr2vr_b(8);
 
     // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
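+    // __lsx_vldi() returns an integer vector (__m128i), so cast explicitly to the float vector type __m128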
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);
 
     for (; ib + 1 < nb; ib += 2) {
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
@@ -2264,7 +2277,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
 