@@ -5986,7 +5986,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
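+        // The 12-byte Q4_K scale field packs 8 six-bit sub-block scales and 8 six-bit
+        // mins; the kmask constants split the packed 32-bit words so that utmp[0..1]
+        // end up holding the 8 scale bytes and mins8 the 8 min bytes.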
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
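+        // Dot the 8 sub-block mins with the Q8 per-16 sums (paired into per-32 sums
+        // above) and subtract, accounting for the -dmin*m offset of the Q4_K format.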
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
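+        // vector_length is the runtime SVE register width in bits (assuming
+        // ggml_cpu_get_sve_cnt() reports the width in bytes); dispatch to a
+        // 128-bit path or a predicated 256/512-bit path below.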
+        switch (vector_length) {
+            case 128:
+                {
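+                    // 128-bit SVE: each 32-byte group of Q4 nibbles is handled as two
+                    // 16-byte halves; low nibbles accumulate into sumi1_1/sumi1_2 with
+                    // scales[2*j], high nibbles into sumi2_1/sumi2_2 with scales[2*j+1].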
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
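+                    // 256-/512-bit SVE: the SV_VL32 / SV_VL8 predicates fix the active
+                    // width at 32 bytes (8 int32 lanes), so one predicated load covers
+                    // a whole 32-byte group of nibbles per svdot.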
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif defined __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
@@ -7756,6 +7837,91 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     }
     *s = sumf;
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
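+        // Each pass of the loop below covers 128 quants: 64 bytes of low nibbles (ql),
+        // 32 bytes holding the upper 2 bits (qh), 128 Q8 values and 8 of the 16 signed
+        // 8-bit block scales.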
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
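+            // Rebuild the 6-bit quants: low 4 bits come from the ql nibbles, the top
+            // 2 bits from qh shifted into bits 4..5, then subtract 32 to recentre the
+            // values around zero before the signed multiply.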
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03, vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03, vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03, vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
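+            // Switch to 16-element halves and apply the per-16-quant scales with a
+            // widening multiply to 32 bits.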
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
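+            // Chain the reductions, seeding each one with the previous partial sum;
+            // isum3 then holds the integer dot product for these 128 quants.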
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64; qh += 32; q8 += 128; is=8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
     const vector int v0 = vec_splats((int32_t)0);