@@ -6596,7 +6596,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
65966596 }
65976597
65986598 *s = hsum_float_8(acc);
6599+ #elif defined(__VXE__) || defined(__VXE2__)
6600+ uint32_t aux[3];
6601+ uint32_t utmp[4];
6602+
6603+ const int32x4_t v_z = vec_splat_s32(0);
6604+ const uint8x16_t v_3m = vec_splat_u8(0x03);
6605+
6606+ const uint8x16_t v_0c = vec_splat_u8(1);
6607+ const uint8x16_t v_1c = vec_sl(v_0c, 1);
6608+ const uint8x16_t v_2c = vec_sl(v_0c, 2);
6609+ const uint8x16_t v_3c = vec_sl(v_0c, 3);
6610+
6611+ uint8x16_t q3h[4];
6612+ uint8x16_t q3b[2];
6613+ int8x16_t q3bytes[4];
6614+ int8x16_t q8bytes[4];
6615+ uint8x16_t qhbits[2];
6616+
6617+ float sum = 0;
6618+
6619+ for (int i = 0; i < nb; ++i) {
6620+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
65996621
6622+ const uint8_t * restrict x0l = x[i].qs;
6623+ const uint8_t * restrict x0h = x[i].hmask;
6624+ const int8_t * restrict y0 = y[i].qs;
6625+
6626+ qhbits[0] = vec_xl(0 , x0h);
6627+ qhbits[1] = vec_xl(16, x0h);
6628+
6629+ int32_t isum = 0;
6630+
6631+ memcpy(aux, x[i].scales, 12);
6632+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6633+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6634+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6635+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6636+
6637+ int8_t * scale = (int8_t *)utmp;
6638+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
6639+
6640+ for (int j = 0; j < QK_K/128; ++j) {
6641+ int32x4_t isum0, isum1, isum2, isum3;
6642+
6643+ q3b[0] = vec_xl(0 , x0l);
6644+ q3b[1] = vec_xl(16, x0l);
6645+ x0l += 32;
6646+
6647+ q8bytes[0] = vec_xl(0 , y0);
6648+ q8bytes[1] = vec_xl(16 , y0);
6649+ q8bytes[2] = vec_xl(32 , y0);
6650+ q8bytes[3] = vec_xl(48 , y0);
6651+ q8bytes[4] = vec_xl(64 , y0);
6652+ q8bytes[5] = vec_xl(80 , y0);
6653+ q8bytes[6] = vec_xl(96 , y0);
6654+ q8bytes[7] = vec_xl(112, y0);
6655+ y0 += 128;
6656+
6657+ q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
6658+ q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
6659+ q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
6660+ q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
6661+
6662+ q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
6663+ q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
6664+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
6665+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
6666+
6667+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
6668+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
6669+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
6670+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
6671+
6672+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6673+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6674+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6675+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6676+
6677+ scale += 4;
6678+
6679+ q3h[0] = vec_andc(v_2c, qhbits[0]);
6680+ q3h[1] = vec_andc(v_2c, qhbits[1]);
6681+ q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
6682+ q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
6683+
6684+ q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
6685+ q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
6686+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
6687+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
6688+
6689+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
6690+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
6691+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
6692+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
6693+
6694+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6695+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6696+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6697+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6698+
6699+ scale += 4;
6700+
6701+ if (j == 0) {
6702+ qhbits[0] = vec_sr(qhbits[0], 4);
6703+ qhbits[1] = vec_sr(qhbits[1], 4);
6704+ }
6705+ }
6706+
6707+ sum += d * isum;
6708+ }
6709+
6710+ *s = sum;
66006711#else
66016712 // scalar version
66026713 // This function is written like this so the compiler can manage to vectorize most of it
0 commit comments