@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
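+    // Q6_K superblock: 128 bytes ql (low 4 bits of each weight), 64 bytes qh
+    // (upper 2 bits, four weights per byte), 16 int8 sub-block scales, fp16 d.
+    // m4b masks the low nibble; mone (0x30) masks the two high bits once qh
+    // has been shifted into bit positions 4-5.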
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
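+        // Each 6-bit weight carries an implicit -32 bias. Dot the 16 per-
+        // sub-block q8 sums (bsums) against the 16 scales so the bias can be
+        // applied once per superblock, as 32 * isum_mins, after the main sum.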
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
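+                        // Low 64 weights: qh<<4 and qh<<2 place the two high
+                        // bits at positions 4-5 (masked by mone), which are
+                        // OR-ed with the low nibbles of ql into 6-bit values.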
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
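+                        // High 64 weights: the high nibbles of ql (q6bits >> 4)
+                        // combined with the remaining qh bit pairs.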
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
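+                    // At 256-bit (or wider) vectors, all 32 qh bytes of a
+                    // 128-weight chunk fit in one register, so each pass
+                    // needs half the loads of the 128-bit path.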
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
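+                        // Broadcast each int8 scale across four consecutive
+                        // bytes with two self-zips, then widen to 32-bit lanes
+                        // so svmla applies one scale per 16-weight sub-block
+                        // (each svdot lane covers 4 bytes).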
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
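For reference, a minimal scalar sketch of the per-superblock sum that both the SVE and NEON paths vectorize, assuming ggml's standard block_q6_K layout; the helper name and loop shape are illustrative, not part of the patch:

    #include <stdint.h>

    // Hypothetical scalar reference: rebuild each 6-bit weight from its low
    // nibble (ql) and high bit pair (qh), dot it against q8, and apply the
    // per-sub-block scales; the caller multiplies by d_all * y[i].d.
    static int32_t q6k_superblock_dot(const uint8_t * ql, const uint8_t * qh,
                                      const int8_t * sc, const int8_t * q8) {
        int32_t isum = 0;
        for (int j = 0; j < 2; ++j) {          // two 128-weight halves
            for (int l = 0; l < 32; ++l) {
                // one qh byte + two ql bytes yield four weights, 32 apart
                const int w0 = ((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int w1 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int w2 = ((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int w3 = ((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                isum += sc[0 + l/16] * w0 * q8[l +  0];
                isum += sc[2 + l/16] * w1 * q8[l + 32];
                isum += sc[4 + l/16] * w2 * q8[l + 64];
                isum += sc[6 + l/16] * w3 * q8[l + 96];
            }
            ql += 64; qh += 32; sc += 8; q8 += 128;
        }
        return isum;
    }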