@@ -3818,7 +3818,7 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
38183818 float sumf = 0;
38193819
38203820#if defined(__ARM_FEATURE_SVE)
3821- if (svcntb() == QK8_0) {
3821+ if (lm_ggml_sve_cnt_b == QK8_0) {
38223822 const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
38233823 const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
38243824
@@ -4190,15 +4190,18 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
41904190 sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
41914191#endif
41924192 for (; ib < nb; ++ib) {
4193- int sumi = 0;
4193+ int sumi0 = 0;
4194+ int sumi1 = 0;
41944195
41954196 for (int j = 0; j < qk/2; ++j) {
41964197 const int v0 = (x[ib].qs[j] & 0x0F) - 8;
41974198 const int v1 = (x[ib].qs[j] >> 4) - 8;
41984199
4199- sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
4200+ sumi0 += (v0 * y[ib].qs[j]);
4201+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
42004202 }
42014203
4204+ int sumi = sumi0 + sumi1;
42024205 sumf += sumi*LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d);
42034206 }
42044207
@@ -4474,15 +4477,18 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
44744477 sumf = hsum_float_8(acc) + summs;
44754478#endif
44764479 for (; ib < nb; ++ib) {
4477- int sumi = 0;
4480+ int sumi0 = 0;
4481+ int sumi1 = 0;
44784482
44794483 for (int j = 0; j < qk/2; ++j) {
44804484 const int v0 = (x[ib].qs[j] & 0x0F);
44814485 const int v1 = (x[ib].qs[j] >> 4);
44824486
4483- sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
4487+ sumi0 += (v0 * y[ib].qs[j]);
4488+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
44844489 }
44854490
4491+ int sumi = sumi0 + sumi1;
44864492 sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s);
44874493 }
44884494
@@ -4823,18 +4829,21 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
48234829 uint32_t qh;
48244830 memcpy(&qh, x[ib].qh, sizeof(qh));
48254831
4826- int sumi = 0;
4832+ int sumi0 = 0;
4833+ int sumi1 = 0;
48274834
48284835 for (int j = 0; j < qk/2; ++j) {
48294836 const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
48304837 const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
48314838
4832- const int32_t x0 = (( x[ib].qs[j] & 0x0F) | xh_0) - 16;
4833- const int32_t x1 = (( x[ib].qs[j] >> 4) | xh_1) - 16;
4839+ const int32_t x0 = (int8_t)((( x[ib].qs[j] & 0x0F) | xh_0) - 16) ;
4840+ const int32_t x1 = (int8_t)((( x[ib].qs[j] >> 4) | xh_1) - 16) ;
48344841
4835- sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]);
4842+ sumi0 += (x0 * y[ib].qs[j]);
4843+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
48364844 }
48374845
4846+ int sumi = sumi0 + sumi1;
48384847 sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi;
48394848 }
48404849
@@ -5194,7 +5203,8 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
51945203 uint32_t qh;
51955204 memcpy(&qh, x[ib].qh, sizeof(qh));
51965205
5197- int sumi = 0;
5206+ int sumi0 = 0;
5207+ int sumi1 = 0;
51985208
51995209 for (int j = 0; j < qk/2; ++j) {
52005210 const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
@@ -5203,9 +5213,11 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
52035213 const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
52045214 const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
52055215
5206- sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]);
5216+ sumi0 += (x0 * y[ib].qs[j]);
5217+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
52075218 }
52085219
5220+ int sumi = sumi0 + sumi1;
52095221 sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s);
52105222 }
52115223
@@ -5291,7 +5303,7 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
52915303 float sumf = 0;
52925304
52935305#if defined(__ARM_FEATURE_SVE)
5294- if (svcntb() == QK8_0) {
5306+ if (lm_ggml_sve_cnt_b == QK8_0) {
52955307 svfloat32_t sumv0 = svdup_n_f32(0.0f);
52965308 svfloat32_t sumv1 = svdup_n_f32(0.0f);
52975309
@@ -6437,22 +6449,22 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
64376449 // compute mask for subtraction
64386450 vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
64396451 vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
6440- vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m (vmask_0, q3_0, 0x4, vl);
6452+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu (vmask_0, q3_0 , q3_0, 0x4, vl);
64416453 m <<= 1;
64426454
64436455 vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
64446456 vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
6445- vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m (vmask_1, q3_1, 0x4, vl);
6457+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu (vmask_1, q3_1 , q3_1, 0x4, vl);
64466458 m <<= 1;
64476459
64486460 vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
64496461 vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
6450- vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m (vmask_2, q3_2, 0x4, vl);
6462+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu (vmask_2, q3_2 , q3_2, 0x4, vl);
64516463 m <<= 1;
64526464
64536465 vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
64546466 vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
6455- vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m (vmask_3, q3_3, 0x4, vl);
6467+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu (vmask_3, q3_3 , q3_3, 0x4, vl);
64566468 m <<= 1;
64576469
64586470 // load Q8 and take product with Q3
@@ -7708,13 +7720,13 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
77087720 vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
77097721 vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
77107722 vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
7711- vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m (vmask_1, q5_a, 16, vl);
7723+ vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu (vmask_1, q5_a , q5_a, 16, vl);
77127724 m <<= 1;
77137725
77147726 vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
77157727 vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
77167728 vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
7717- vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m (vmask_2, q5_l, 16, vl);
7729+ vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu (vmask_2, q5_l , q5_l, 16, vl);
77187730 m <<= 1;
77197731
77207732 vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
0 commit comments