Skip to content

Commit b4b2214

Browse files
MQ-mengqing authored and taronaeo committed
ggml: resolve pr merge via cherry-pick 225bbbf
Signed-off-by: Aaron Teo <[email protected]>
1 parent 5796caf commit b4b2214

File tree

3 files changed

+22
-57
lines changed

3 files changed

+22
-57
lines changed

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -511,21 +511,15 @@ inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
511511
#endif
512512

513513
#if defined(__loongarch_asx)
514-
515-
typedef union {
516-
int32_t i;
517-
float f;
518-
} ft_union;
519-
520514
/* float type data load instructions */
521-
static __m128 __lsx_vreplfr2vr_s(float val) {
522-
ft_union fi_tmpval = {.f = val};
523-
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
515+
static __m128 __lsx_vreplfr2vr_s(const float val) {
516+
v4f32 res = {val, val, val, val};
517+
return (__m128)res;
524518
}
525519

526-
static __m256 __lasx_xvreplfr2vr_s(float val) {
527-
ft_union fi_tmpval = {.f = val};
528-
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
520+
static __m256 __lasx_xvreplfr2vr_s(const float val) {
521+
v8f32 res = {val, val, val, val, val, val, val, val};
522+
return (__m256)res;
529523
}
530524
#endif
531525

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 7 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -434,30 +434,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
434434
}
435435

436436
static __m256i lasx_extu8_16(__m128i a) {
437-
__m128i zero = __lsx_vldi(0);
438-
__m128i vlo = __lsx_vilvl_b(zero, a);
439-
__m128i vhi = __lsx_vilvh_b(zero, a);
440-
return lasx_set_q(vhi, vlo);
437+
return __lasx_vext2xv_hu_bu(____m256i(a));
441438
}
442439

443440
static __m256i lasx_ext8_16(__m128i a) {
444-
__m128i sign = __lsx_vslti_b(a, 0);
445-
__m128i vlo = __lsx_vilvl_b(sign, a);
446-
__m128i vhi = __lsx_vilvh_b(sign, a);
447-
return lasx_set_q(vhi, vlo);
441+
return __lasx_vext2xv_h_b(____m256i(a));
448442
}
449443

450444
static __m256i lasx_ext16_32(__m128i a) {
451-
__m256i tmp1;
452-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
453-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
454-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
455-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
456-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
457-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
458-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
459-
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
460-
return tmp1;
445+
return __lasx_vext2xv_w_h(____m256i(a));
461446
}
462447

463448
static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -580,12 +565,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
580565
// horizontally add 8 floats
581566
static inline float hsum_float_8(const __m256 x) {
582567
__m128 res = lasx_extractf128(x, 1);
583-
ft_union tmp;
584568
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
585569
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
586570
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
587-
tmp.i = __lsx_vpickve2gr_w(res, 0);
588-
return tmp.f;
571+
return ((v4f32)res)[0];
589572
}
590573

591574
// horizontally add 8 int32_t
@@ -927,7 +910,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
927910

928911
#elif defined(__loongarch_asx)
929912
for (int i = 0; i < nb; i++) {
930-
ft_union fi;
931913
__m256 v0 = (__m256)__lasx_xvld( x , 0);
932914
__m256 v1 = (__m256)__lasx_xvld( x , 32);
933915
__m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -945,8 +927,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
945927
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
946928
__m128 tmp = max4;
947929
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
948-
fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
949-
const float max_scalar = fi.f;
930+
const float max_scalar = ((v4f32)max4)[0];
950931

951932
// Quantize these floats
952933
const float d = max_scalar / 127.f;
@@ -1283,7 +1264,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
12831264

12841265
#elif defined(__loongarch_asx)
12851266
for (int i = 0; i < nb; i++) {
1286-
ft_union ft;
12871267
__m256 v0 = (__m256)__lasx_xvld( x , 0 );
12881268
__m256 v1 = (__m256)__lasx_xvld( x , 32 );
12891269
__m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1301,8 +1281,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
13011281
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
13021282
__m128 tmp = max4;
13031283
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
1304-
ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
1305-
const float max_scalar = ft.f;
1284+
const float max_scalar = ((v4f32)max4)[0];
13061285

13071286
// Quantize these floats
13081287
const float d = max_scalar / 127.f;
@@ -6211,9 +6190,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
62116190
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
62126191

62136192

6214-
ft_union fi;
6215-
fi.i = __lsx_vpickve2gr_w(acc_m, 0);
6216-
*s = hsum_float_8(acc) + fi.f ;
6193+
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
62176194
#elif defined(__VXE__) || defined(__VXE2__)
62186195
const uint8x16_t v_lm = vec_splat_u8(0x0F);
62196196
const int32x4_t v_z = vec_splat_s32(0);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1080,29 +1080,23 @@ do { \
10801080
#define GGML_F16_STEP 32
10811081
#define GGML_F16_EPR 8
10821082

1083-
// F16 arithmetic is not supported by AVX, so we use F32 instead
1083+
// F16 arithmetic is not supported by LASX, so we use F32 instead
10841084

10851085
#define GGML_F32Cx8 __m256
10861086
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
10871087
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
10881088

10891089
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
1090-
float tmp[8];
1091-
1092-
for (int i = 0; i < 8; i++) {
1093-
tmp[i] = GGML_FP16_TO_FP32(x[i]);
1094-
}
1095-
1096-
return (__m256)__lasx_xvld(tmp, 0);
1090+
__m256i a;
1091+
memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
1092+
a = __lasx_xvpermi_d(a, 0 | (1 << 4));
1093+
return __lasx_xvfcvtl_s_h(a);
10971094
}
1098-
static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1099-
float arr[8];
11001095

1101-
__lasx_xvst(y, arr, 0);
1102-
1103-
for (int i = 0; i < 8; i++) {
1104-
x[i] = GGML_FP32_TO_FP16(arr[i]);
1105-
}
1096+
static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1097+
__m256i a = __lasx_xvfcvt_h_s(y, y);
1098+
a = __lasx_xvpermi_d(a, 0 | (2 << 2));
1099+
memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
11061100
}
11071101
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
11081102
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)

0 commit comments

Comments
 (0)