Commit e5aeb42

fix bad merging
1 parent 610b3ac commit e5aeb42

File tree: 1 file changed, +167 -1 lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 167 additions & 1 deletion
@@ -5986,7 +5986,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif defined __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
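The restored SVE branch computes the same per-superblock quantity as the NEON path below it: a scale-weighted dot product of the 4-bit nibbles of q4 against the int8 values of q8, minus a correction term built from the sub-block minimums and the precomputed q8 block sums (bsums). For orientation, here is a minimal scalar sketch of that math; the helper name, its parameter list, and the assumption that the 6-bit scales and mins have already been unpacked are illustrative choices, not code from this file.

    #include <stdint.h>

    // Scalar sketch of one q4_K x q8_K superblock (QK_K == 256 assumed).
    // Hypothetical helper: scales/mins are the 8 unpacked 6-bit sub-block scales and
    // mins (the SVE path unpacks them from x[i].scales via kmask1..kmask3), q4 holds
    // 128 bytes carrying two 4-bit values each, q8 holds 256 int8 values, and bsums
    // holds 16 partial sums of 16 q8 values each (y[i].bsums).
    static float q4K_q8K_superblock_sketch(float d, float dmin,
                                           const uint8_t scales[8], const uint8_t mins[8],
                                           const uint8_t * q4, const int8_t * q8,
                                           const int16_t * bsums) {
        int32_t sumi = 0;
        for (int j = 0; j < 4; ++j) {              // 4 groups of 64 quantized values
            int32_t lo = 0, hi = 0;
            for (int k = 0; k < 32; ++k) {
                lo += (q4[k] & 0x0F) * q8[k];      // low nibbles -> sub-block 2*j
                hi += (q4[k] >>   4) * q8[k + 32]; // high nibbles -> sub-block 2*j+1
            }
            sumi += lo * scales[2*j] + hi * scales[2*j + 1];
            q4 += 32; q8 += 64;
        }
        int32_t summin = 0;                        // correction for the sub-block minimums
        for (int j = 0; j < 8; ++j) {
            summin += mins[j] * (bsums[2*j] + bsums[2*j + 1]);
        }
        return d * sumi - dmin * summin;           // d and dmin already include y[i].d
    }

The switch on vector_length = ggml_cpu_get_sve_cnt()*8 chooses between a 128-bit variant, which splits each 32-byte group into two 16-byte predicated loads and keeps separate partial accumulators, and a 256/512-bit variant that consumes a full 32-byte group per load under SV_VL32 predicates; both paths end in the same horizontal reduction into sumf.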

@@ -7756,6 +7837,91 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     }
     *s = sumf;
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64; qh += 32; q8 += 128; is=8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
     const vector int v0 = vec_splats((int32_t)0);
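The restored RISC-V vector branch of ggml_vec_dot_q6_K_q8_K rebuilds each 6-bit weight from a low nibble in ql and a 2-bit field in qh, recenters it by 32, and accumulates per-16-value scale-weighted dot products against q8 over two 128-value halves of the superblock. A minimal scalar sketch of that math follows; the helper name and its parameter list are illustrative choices, not code from this file.

    #include <stdint.h>

    // Scalar sketch of one q6_K x q8_K superblock (QK_K == 256 assumed).
    // Hypothetical helper: ql is 128 bytes of low nibbles (x[i].ql), qh is 64 bytes of
    // packed 2-bit high parts (x[i].qh), scale is 16 signed per-16-value scales
    // (x[i].scales), q8 is 256 int8 values (y[i].qs).
    static float q6K_q8K_superblock_sketch(float d, const uint8_t * ql, const uint8_t * qh,
                                           const int8_t * scale, const int8_t * q8) {
        int32_t sumi = 0;
        for (int half = 0; half < 2; ++half) {     // two halves of 128 values each
            for (int k = 0; k < 32; ++k) {
                // rebuild four 6-bit values from one qh byte and two ql bytes, recenter by 32
                const int q0 = ((ql[k]      & 0x0F) | (((qh[k] >> 0) & 3) << 4)) - 32;
                const int q1 = ((ql[k + 32] & 0x0F) | (((qh[k] >> 2) & 3) << 4)) - 32;
                const int q2 = ((ql[k]      >>   4) | (((qh[k] >> 4) & 3) << 4)) - 32;
                const int q3 = ((ql[k + 32] >>   4) | (((qh[k] >> 6) & 3) << 4)) - 32;
                sumi += scale[k/16 + 0] * q0 * q8[k +  0]
                      + scale[k/16 + 2] * q1 * q8[k + 32]
                      + scale[k/16 + 4] * q2 * q8[k + 64]
                      + scale[k/16 + 6] * q3 * q8[k + 96];
            }
            ql += 64; qh += 32; q8 += 128; scale += 8;
        }
        return d * sumi;                           // d = GGML_FP16_TO_FP32(x[i].d) * y[i].d
    }

In the vector code the same grouping appears as the eight vaux_* products, taken from the 16-element halves of the widened multiplies and scaled by scale[is+0..7] before the chained vredsum reductions accumulate into sum_t.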
