Skip to content

Commit 58c13d0

Browse files
author
Iwan Kawrakow
committed
q8_KV: ARM_NEON
We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well.
1 parent 10168ab commit 58c13d0

File tree

2 files changed

+109
-29
lines changed

2 files changed

+109
-29
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13665,6 +13665,81 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
1366513665
}
1366613666
}
1366713667

13668+
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
13669+
GGML_ASSERT(n%32 == 0);
13670+
int32x4_t acc[4] = {};
13671+
auto dptr = (const float *)info.src1_row(0);
13672+
const float dy = dptr[0];
13673+
auto q8y = (const int8_t *)(dptr + 2);
13674+
for (int ix = 0; ix < nrc_x; ++ix) {
13675+
auto dx = (const float *)((const char *)vx + ix*bx);
13676+
auto q8x = (const int8_t *)(dx + 2);
13677+
for (int i = 0; i < n/64; ++i) {
13678+
auto qx = vld1q_s8_x4(q8x + 64*i);
13679+
for (int j = 0; j < 4; ++j) {
13680+
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
13681+
}
13682+
}
13683+
if (int i = 2*(n/64); i < n/32) {
13684+
auto qx = vld1q_s8_x2(q8x + 32*i);
13685+
for (int j = 0; j < 2; ++j) {
13686+
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
13687+
}
13688+
}
13689+
acc[0] = vaddq_s32(acc[0], acc[1]);
13690+
acc[2] = vaddq_s32(acc[2], acc[3]);
13691+
acc[0] = vaddq_s32(acc[0], acc[2]);
13692+
info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
13693+
acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
13694+
}
13695+
}
13696+
13697+
template <int nrc_y>
13698+
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
13699+
GGML_ASSERT(nrc_x%4 == 0);
13700+
GGML_ASSERT(n%16 == 0);
13701+
int8x16_t qx[4];
13702+
int32x4_t acc[nrc_y] = {};
13703+
float dy[nrc_y];
13704+
const int8_t * q8y[nrc_y];
13705+
for (int iy = 0; iy < nrc_y; ++iy) {
13706+
auto dptr = (const float *)info.src1_row(iy);
13707+
dy[iy] = dptr[0];
13708+
q8y[iy] = (const int8_t *)(dptr + 2);
13709+
}
13710+
const int8_t * q8x[4];
13711+
float dx[4];
13712+
for (int ix = 0; ix < nrc_x; ix += 4) {
13713+
for (int kx = 0; kx < 4; ++kx) {
13714+
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
13715+
dx[kx] = dptr[0];
13716+
q8x[kx] = (const int8_t *)(dptr + 2);
13717+
}
13718+
for (int i = 0; i < n/16; ++i) {
13719+
for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
13720+
auto row01 = vtrnq_s32(qx[0], qx[1]);
13721+
auto row23 = vtrnq_s32(qx[2], qx[3]);
13722+
qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
13723+
qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
13724+
qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
13725+
qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
13726+
for (int iy = 0; iy < nrc_y; ++iy) {
13727+
auto y = vld1q_s8(q8y[iy] + 16*i);
13728+
acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
13729+
acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
13730+
acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
13731+
acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
13732+
}
13733+
}
13734+
auto scales_x = vld1q_f32(dx);
13735+
for (int iy = 0; iy < nrc_y; ++iy) {
13736+
auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
13737+
info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
13738+
acc[iy] = vdupq_n_s32(0);
13739+
}
13740+
}
13741+
}
13742+
1366813743
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
1366913744
GGML_ASSERT(nrc_x%4 == 0);
1367013745
Q8<1, block_q8_0_x4> q8(info);
@@ -14241,6 +14316,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
1424114316
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
1424214317
expected_Btype = GGML_TYPE_Q8_KR8;
1424314318
break;
14319+
case GGML_TYPE_Q8_KV:
14320+
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
14321+
m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
14322+
m.func16 = mul_mat_q8_KV_q8_KV<16>;
14323+
expected_Btype = GGML_TYPE_Q8_KV;
14324+
break;
1424414325
case GGML_TYPE_IQ2_K_R4:
1424514326
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
1424614327
expected_Btype = GGML_TYPE_Q8_K;
@@ -14605,9 +14686,8 @@ struct HelperQ8KV final : public BaseHelper<step> {
1460514686
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
1460614687
auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
1460714688
#ifdef __aarch64__
14608-
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
14609-
int ii = j%QK8_0;
14610-
auto qs = vld1_s8_x2(dl->qs + ii);
14689+
auto vd = F16::set1(q8->d);
14690+
auto qs = vld1_s8_x2(q8->qs + 8*i);
1461114691
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
1461214692
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
1461314693
#else
@@ -14859,28 +14939,26 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
1485914939
_mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
1486014940
#elif defined __ARM_NEON
1486114941
// TODO
14862-
for (int l = 0; l < 2; ++l) {
14863-
m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
14864-
m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
14865-
m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
14866-
m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
14867-
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
14868-
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
14869-
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14870-
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14871-
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14872-
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14873-
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
14874-
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
14875-
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14876-
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14877-
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14878-
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14879-
vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
14880-
vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
14881-
vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
14882-
vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
14883-
}
14942+
m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
14943+
m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
14944+
m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
14945+
m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
14946+
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
14947+
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
14948+
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14949+
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14950+
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14951+
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14952+
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
14953+
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
14954+
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14955+
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14956+
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14957+
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14958+
vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
14959+
vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
14960+
vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
14961+
vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
1488414962
#else
1488514963
// TODO
1488614964
for (int l = 0; l < 4; ++l) {

ggml/src/iqk/iqk_quantize.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,19 +3016,21 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
30163016
int32x4_t ival[8];
30173017
auto vmax = vdupq_n_f32(0.f);
30183018
for (int j = 0; j < k; j += 4) {
3019-
vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(xb + j)));
3019+
vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
30203020
}
30213021
auto smax = vmaxvq_f32(vmax);
30223022
if (!smax) {
30233023
dptr[0] = dptr[1] = 0;
30243024
std::memset(q8, 0, k*sizeof(int8_t));
30253025
return;
30263026
}
3027-
auto vid = vdupq_n_f32(127/smax);
3027+
dptr[0] = smax/127;
3028+
auto vid = vdupq_n_f32(1/dptr[0]);
30283029
auto isum = vdupq_n_s32(0);
30293030
for (int ib = 0; ib < k/32; ++ib) {
3031+
auto xb = x + 32*ib;
30303032
for (int k = 0; k < 8; ++k) {
3031-
auto val = vld1q_f32(xb + 32*ib + 4*k);
3033+
auto val = vld1q_f32(xb + 4*k);
30323034
ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
30333035
isum = vaddq_s32(isum, ival[k]);
30343036
}
@@ -6549,7 +6551,7 @@ void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
65496551
iqk_quantize_row_q8_KV(x, vy, k);
65506552
}
65516553

6552-
void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
6554+
void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
65536555
quantize_row_q8_KV(x, y, k);
65546556
}
65556557

0 commit comments

Comments
 (0)