@@ -5112,7 +5112,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#if defined(__ARM_FEATURE_SVE)
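+
+    // SVE implementation of the q3_K * q8_K dot product; the kernel variant is
+    // selected at run time from the SVE register width reported by svcntb().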
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const int8_t m32 = 32;
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3b_sv = svdup_n_u8(0x3);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+
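+    // Bit masks 1, 2, 4 and 8: each selects one bit-plane of the hmask bytes,
+    // i.e. the high bit of one 32-weight group.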
+    const svuint8_t m0_sv = svdup_n_u8(1);
+    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
+    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
+    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3_sv = x[i].qs;
+        const uint8_t * restrict qh_sv = x[i].hmask;
+        const int8_t  * restrict q8_sv = y[i].qs;
+
+        // Set up scales
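+        // (12 bytes pack 16 6-bit scales: the low 4 bits of each scale sit in
+        // the first 8 bytes, the high 2 bits in the last 4; kmask2/kmask1
+        // split them apart.)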
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
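+        // Dispatch on the runtime vector width: the 128-bit kernel works on
+        // two 16-byte rows per step, 256-bit and wider on one 32-byte row.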
+        switch (vector_length) {
+            case 128:
+                {
+                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
+                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
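+                    // Each pass decodes 128 of the 256 weights: four 2-bit
+                    // planes of q3bits, each with its high bit restored from
+                    // qhbits and dotted against the matching q8 values.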
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv   = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
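+                        // (mN & ~qh) is 1 where the high bit is clear; scaled
+                        // to 4 and subtracted, it maps a 2-bit value v to v-4
+                        // (bit clear) or v (bit set), the signed 3-bit quant.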
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
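+                        // The first pass consumed hmask bits 0..3; shift bits
+                        // 4..7 into place for the second 128 weights.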
+                        if (j == 0) {
+                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
+                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
+                } break;
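+            // With 256-bit (or wider) registers a 32-byte row fits in one
+            // load; svdot then yields eight 32-bit lanes, so two sub-block
+            // scales are paired per MLA via svsel (the first four lanes take
+            // one scale, the next four the other).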
+            case 256:
+            case 512:
+                {
+                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
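+                        // Shift hmask bits 4..7 into place for the second
+                        // 128-weight pass.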
+                        if (j == 0) {
+                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sum;
+
+#elif defined(__ARM_NEON)
 
     uint32_t aux[3];
     uint32_t utmp[4];