@@ -4587,7 +4587,252 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
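+    // Runtime dispatch on the SVE register width: svcntb() returns the vector
+    // length in bytes, so vector_length is the length in bits. m3s extracts one
+    // 2-bit quant from each byte, m4s keeps the low nibble of each widened scale
+    // byte, and pred_s32 enables the first four 32-bit lanes for svsel below.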
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3s = svdup_n_u8(0x3);
+    const svuint32_t m4s = svdup_n_u32(0xF);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+    svfloat32_t acc_sum = svdup_n_f32(0);
+    svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
+
+    switch (vector_length) {
+        case 128:
+            for (int i = 0; i < nb; ++i) {
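+                // d and dmin fold the per-block fp16 super-scales of x with the
+                // fp32 scale of y; dmin is negated up front so the minimum
+                // correction below can use a plain multiply-add.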
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
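+                // Each scales[] byte packs a 4-bit scale (low nibble) and a
+                // 4-bit min (high nibble). svld1ub_u32 zero-extends 4 bytes into
+                // 4 u32 lanes; shifting right by 4 isolates the mins.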
+                svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
+
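+                // The block minimums contribute -dmin * sum_j(min_j * bsum_j),
+                // where bsum_j is the precomputed sum of the 16 q8 values of
+                // sub-block j.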
+                const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
+                const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
+                const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
+
+                q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
+                q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
+
+                svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
+
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
+
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
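+                    // Each q2 byte packs four 2-bit quants at bit offsets
+                    // 0/2/4/6. Every shift-and-mask yields 16 quants, which
+                    // svdot multiplies with 16 q8 bytes, accumulating into 4
+                    // s32 lanes; each group of 16 carries its own 4-bit scale,
+                    // broadcast with svdup_lane_s32.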
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                    const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
+
+                    const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
+
+                    const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
+
+                    // second half: sub-blocks 8..15 use the next 32 bytes of q2
+                    q2 += 32;
+                    const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
+
+                    const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
+
+                    const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                    sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
+                }
+                // fold the integer dot product into the accumulator, scaled by d
+                acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_b32(), acc_sum); // horizontal add across lanes
+            break;
+
+        case 256:
+        case 512:
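+            // 256-bit and 512-bit registers share one kernel: the fixed-length
+            // predicates SV_VL8 (8 s32 lanes) and SV_VL32 (32 bytes) restrict
+            // every operation to 256 bits, so wider hardware runs it unchanged.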
+            for (int i = 0; i < nb; ++i) {
+                const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+                svfloat32_t d_broad = svdup_n_f32((float32_t)d);
+                const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+                svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
+
+                const uint8_t * restrict q2 = x[i].qs;
+                const int8_t * restrict q8_sv = y[i].qs;
+                const uint8_t * restrict sc = x[i].scales;
+
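+                // with 8 s32 lanes, one load widens 8 scale bytes at a time:
+                // low nibbles are the sub-block scales, high nibbles the mins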
+                const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
+                const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
+                const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
+                svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
+
+                const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
+                const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
+                const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
+
+                svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
+
+                // minimum correction: dmin (already negated) times sum_j(min_j * bsum_j)
+                svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
+
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
+
+                svint32_t sumi1 = svdup_n_s32(0);
+
+                {
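+                    // One 32-byte load spans two 16-quant sub-blocks, so the
+                    // dot products landing in lanes 0..3 and 4..7 need
+                    // different scales; svsel with pred_s32 (first 4 lanes
+                    // active) splices scale 2k into the low half and scale
+                    // 2k+1 into the high half of the scale vector.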
+                    const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
+                    svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    // second 32 bytes of q2: sub-blocks 8..15
+                    q2 += 32;
+
+                    const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
+
+                    q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
+                    q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                    scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
+                    sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
+                }
+                acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
+            }
+            *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum); // horizontal add across lanes
+            break;
+
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif __ARM_NEON
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);
 