Skip to content

Commit 1ecea16

Browse files
author
Iwan Kawrakow
committed
q8_KV: slightly faster gemv on Zen4
1 parent 7f4ec2f commit 1ecea16

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 26 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -6245,47 +6245,41 @@ template <int nrc_y>
62456245
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
62466246
GGML_ASSERT(nrc_x%8 == 0);
62476247
GGML_ASSERT(n%32 == 0);
6248-
#ifndef HAVE_FANCY_SIMD
6248+
__m256i qx[2];
6249+
__m256i acc[nrc_y] = {};
6250+
float dy[nrc_y];
6251+
#ifdef HAVE_FANCY_SIMD
6252+
int32_t sy[nrc_y];
6253+
#else
6254+
__m256i sx[2];
62496255
auto m1 = _mm256_set1_epi16(1);
62506256
#endif
6251-
__m256i qx[4];
6252-
__m256i sx[4];
6253-
__m256i acc[nrc_y] = {};
6254-
float dy[nrc_y];
62556257
const int8_t * q8y[nrc_y];
62566258
for (int iy = 0; iy < nrc_y; ++iy) {
62576259
auto dptr = (const float *)info.src1_row(iy);
62586260
dy[iy] = dptr[0];
6261+
#ifdef HAVE_FANCY_SIMD
6262+
auto iptr = (const int32_t *)(dptr+1);
6263+
sy[iy] = -127*iptr[0];
6264+
#endif
62596265
q8y[iy] = (const int8_t *)(dptr + 2);
62606266
}
62616267
for (int ix = 0; ix < nrc_x; ++ix) {
62626268
auto dx = (const float *)((const char *)vx + ix*bx);
62636269
auto q8x = (const int8_t *)(dx + 2);
6264-
for (int i = 0; i < n/128; ++i) {
6265-
for (int j = 0; j < 4; ++j) {
6266-
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
6267-
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6268-
}
6269-
for (int iy = 0; iy < nrc_y; ++iy) {
6270-
for (int j = 0; j < 4; ++j) {
6270+
for (int i = 0; i < n/64; ++i) {
6271+
for (int j = 0; j < 2; ++j) {
62716272
#ifdef HAVE_FANCY_SIMD
6272-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6273+
qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
62736274
#else
6274-
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6275-
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6276-
#endif
6277-
}
6278-
}
6279-
}
6280-
for (int i = 2*(n/128); i < n/64; ++i) {
6281-
for (int j = 0; j < 2; ++j) {
62826275
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
62836276
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6277+
#endif
62846278
}
62856279
for (int iy = 0; iy < nrc_y; ++iy) {
62866280
for (int j = 0; j < 2; ++j) {
62876281
#ifdef HAVE_FANCY_SIMD
6288-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6282+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
62896283
#else
62906284
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
62916285
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
@@ -6294,11 +6288,15 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
62946288
}
62956289
}
62966290
if (int i = 2*(n/64); i < n/32) {
6291+
#ifdef HAVE_FANCY_SIMD
6292+
qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
6293+
#else
62976294
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
62986295
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6296+
#endif
62996297
for (int iy = 0; iy < nrc_y; ++iy) {
63006298
#ifdef HAVE_FANCY_SIMD
6301-
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6299+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
63026300
#else
63036301
auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
63046302
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
@@ -6307,7 +6305,11 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
63076305
}
63086306
for (int iy = 0; iy < nrc_y; ++iy) {
63096307
auto sumi = hsum_i32_8(acc[iy]);
6310-
info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6308+
#ifdef HAVE_FANCY_SIMD
6309+
info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
6310+
#else
6311+
info.store(ix, iy, dx[0]*dy[iy]*sumi);
6312+
#endif
63116313
acc[iy] = _mm256_setzero_si256();
63126314
}
63136315
}

0 commit comments

Comments
 (0)