Skip to content

Commit 0d7885f

Browse files
author
Iwan Kawrakow
committed
q8_KV: Better Zen4 gemm
We get 225.7 t/s for L3-8B. In comparison q8_0 without run-tinme-repacking is at 169 t/s.
1 parent 7979f85 commit 0d7885f

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6230,13 +6230,16 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62306230
GGML_ASSERT(nrc_x%8 == 0);
62316231
GGML_ASSERT(n%32 == 0);
62326232
__m256i qx[4];
6233-
__m256i sx[4];
6233+
//__m256i sx[4];
62346234
__m256i acc[nrc_y] = {};
62356235
float dy[nrc_y];
6236+
int32_t sy[nrc_y];
62366237
const int8_t * q8y[nrc_y];
62376238
for (int iy = 0; iy < nrc_y; ++iy) {
62386239
auto dptr = (const float *)info.src1_row(iy);
62396240
dy[iy] = dptr[0];
6241+
auto iptr = (const int32_t *)(dptr + 1);
6242+
sy[iy] = -127*iptr[0];
62406243
q8y[iy] = (const int8_t *)(dptr + 2);
62416244
}
62426245
const int8_t * q8x[4];
@@ -6253,23 +6256,35 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62536256
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
62546257
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
62556258
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6256-
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6257-
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6258-
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6259-
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6259+
//qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6260+
//qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6261+
//qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6262+
//qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6263+
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
6264+
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
6265+
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
6266+
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
62606267
for (int iy = 0; iy < nrc_y; ++iy) {
62616268
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6262-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6263-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6264-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6265-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6269+
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6270+
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6271+
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6272+
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6273+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
6274+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
6275+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
6276+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
62666277
}
62676278
}
62686279
auto scales_x = _mm_loadu_ps(dx);
62696280
for (int iy = 0; iy < nrc_y; ++iy) {
62706281
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6282+
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
62716283
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
62726284
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6285+
//auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
6286+
//auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
6287+
//info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
62736288
acc[iy] = _mm256_setzero_si256();
62746289
}
62756290
}

ggml/src/iqk/iqk_quantize.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3010,7 +3010,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
30103010
_mm256_storeu_si256((__m256i *)q8, i0);
30113011
q8 += 32;
30123012
}
3013-
dptr[1] = dptr[0] * hsum_i32_8(isum);
3013+
auto iptr = (int32_t *)(dptr + 1);
3014+
iptr[0] = hsum_i32_8(isum);
30143015
#elif defined __ARM_NEON
30153016
int32x4_t ival[8];
30163017
auto vmax = vdupq_n_f32(0.f);
@@ -3037,7 +3038,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
30373038
q8 += 8;
30383039
}
30393040
}
3040-
dptr[1] = dptr[0] * vaddvq_s32(isum);
3041+
auto iptr = (int32_t *)(dptr + 1);
3042+
iptr[0] = vaddvq_s32(isum);
30413043
#else
30423044
float amax = 0;
30433045
for (int j = 0; j < k; ++j) {
@@ -3056,7 +3058,8 @@ void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
30563058
q8[i] = nearest_int(id*x[i]);
30573059
isum += q8[i];
30583060
}
3059-
dptr[1] = dptr[0]*isum;
3061+
auto iptr = (int32_t *)(dptr + 1);
3062+
iptr[0] = isum;
30603063
#endif
30613064
}
30623065
}

0 commit comments

Comments
 (0)