Skip to content

Commit e1bdd14

Browse files
taronaeo authored and ggerganov committed
ggml : activate s390x simd for Q3_K (llama/13301)
Signed-off-by: Aaron Teo <[email protected]>
1 parent 7fa8bb3 commit e1bdd14

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 111 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6590,7 +6590,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
65906590
}
65916591

65926592
*s = hsum_float_8(acc);
6593+
#elif defined(__VXE__) || defined(__VXE2__)
6594+
uint32_t aux[3];
6595+
uint32_t utmp[4];
6596+
6597+
const int32x4_t v_z = vec_splat_s32(0);
6598+
const uint8x16_t v_3m = vec_splat_u8(0x03);
6599+
6600+
const uint8x16_t v_0c = vec_splat_u8(1);
6601+
const uint8x16_t v_1c = vec_sl(v_0c, 1);
6602+
const uint8x16_t v_2c = vec_sl(v_0c, 2);
6603+
const uint8x16_t v_3c = vec_sl(v_0c, 3);
6604+
6605+
uint8x16_t q3h[4];
6606+
uint8x16_t q3b[2];
6607+
int8x16_t q3bytes[4];
6608+
// Eight 16-byte vectors: each inner-loop iteration loads 128 bytes of y[i].qs
// into q8bytes[0..7], so the array must have 8 elements (4 overflows the stack array).
int8x16_t q8bytes[8];
6609+
uint8x16_t qhbits[2];
6610+
6611+
float sum = 0;
6612+
6613+
for (int i = 0; i < nb; ++i) {
6614+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
65936615

6616+
const uint8_t * restrict x0l = x[i].qs;
6617+
const uint8_t * restrict x0h = x[i].hmask;
6618+
const int8_t * restrict y0 = y[i].qs;
6619+
6620+
qhbits[0] = vec_xl(0 , x0h);
6621+
qhbits[1] = vec_xl(16, x0h);
6622+
6623+
int32_t isum = 0;
6624+
6625+
memcpy(aux, x[i].scales, 12);
6626+
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6627+
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6628+
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6629+
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6630+
6631+
int8_t * scale = (int8_t *)utmp;
6632+
for (int j = 0; j < 16; ++j) scale[j] -= 32;
6633+
6634+
for (int j = 0; j < QK_K/128; ++j) {
6635+
int32x4_t isum0, isum1, isum2, isum3;
6636+
6637+
q3b[0] = vec_xl(0 , x0l);
6638+
q3b[1] = vec_xl(16, x0l);
6639+
x0l += 32;
6640+
6641+
q8bytes[0] = vec_xl(0 , y0);
6642+
q8bytes[1] = vec_xl(16 , y0);
6643+
q8bytes[2] = vec_xl(32 , y0);
6644+
q8bytes[3] = vec_xl(48 , y0);
6645+
q8bytes[4] = vec_xl(64 , y0);
6646+
q8bytes[5] = vec_xl(80 , y0);
6647+
q8bytes[6] = vec_xl(96 , y0);
6648+
q8bytes[7] = vec_xl(112, y0);
6649+
y0 += 128;
6650+
6651+
q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
6652+
q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
6653+
q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
6654+
q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
6655+
6656+
q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
6657+
q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
6658+
q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
6659+
q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
6660+
6661+
isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
6662+
isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
6663+
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
6664+
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
6665+
6666+
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6667+
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6668+
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6669+
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6670+
6671+
scale += 4;
6672+
6673+
q3h[0] = vec_andc(v_2c, qhbits[0]);
6674+
q3h[1] = vec_andc(v_2c, qhbits[1]);
6675+
q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
6676+
q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
6677+
6678+
q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
6679+
q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
6680+
q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
6681+
q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
6682+
6683+
isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
6684+
isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
6685+
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
6686+
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
6687+
6688+
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6689+
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6690+
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6691+
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6692+
6693+
scale += 4;
6694+
6695+
if (j == 0) {
6696+
qhbits[0] = vec_sr(qhbits[0], 4);
6697+
qhbits[1] = vec_sr(qhbits[1], 4);
6698+
}
6699+
}
6700+
6701+
sum += d * isum;
6702+
}
6703+
6704+
*s = sum;
65946705
#else
65956706
// scalar version
65966707
// This function is written like this so the compiler can manage to vectorize most of it

0 commit comments

Comments
 (0)