Skip to content

Commit a4ffe2e

Browse files
author
Iwan Kawrakow
committed
q8_KV: AVX2 gemm/gemv
We get 254 t/s for L3-8B, vs 194 t/s for q8_0 without run-time repacking (rtr).
1 parent 0d7885f commit a4ffe2e

File tree

1 file changed

+45
-12
lines changed

1 file changed

+45
-12
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 45 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6175,6 +6175,9 @@ template <int nrc_y>
61756175
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
61766176
GGML_ASSERT(nrc_x%8 == 0);
61776177
GGML_ASSERT(n%32 == 0);
6178+
#ifndef HAVE_FANCY_SIMD
6179+
auto m1 = _mm256_set1_epi16(1);
6180+
#endif
61786181
__m256i qx[4];
61796182
__m256i sx[4];
61806183
__m256i acc[nrc_y] = {};
@@ -6195,7 +6198,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
61956198
}
61966199
for (int iy = 0; iy < nrc_y; ++iy) {
61976200
for (int j = 0; j < 4; ++j) {
6201+
#ifdef HAVE_FANCY_SIMD
61986202
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6203+
#else
6204+
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6205+
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6206+
#endif
61996207
}
62006208
}
62016209
}
@@ -6206,15 +6214,25 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
62066214
}
62076215
for (int iy = 0; iy < nrc_y; ++iy) {
62086216
for (int j = 0; j < 2; ++j) {
6217+
#ifdef HAVE_FANCY_SIMD
62096218
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6219+
#else
6220+
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6221+
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6222+
#endif
62106223
}
62116224
}
62126225
}
62136226
if (int i = 2*(n/64); i < n/32) {
62146227
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
62156228
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
62166229
for (int iy = 0; iy < nrc_y; ++iy) {
6230+
#ifdef HAVE_FANCY_SIMD
62176231
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6232+
#else
6233+
auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6234+
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6235+
#endif
62186236
}
62196237
}
62206238
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -6230,16 +6248,23 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62306248
GGML_ASSERT(nrc_x%8 == 0);
62316249
GGML_ASSERT(n%32 == 0);
62326250
__m256i qx[4];
6233-
//__m256i sx[4];
6251+
#ifndef HAVE_FANCY_SIMD
6252+
__m256i sx[4];
6253+
auto m1 = _mm256_set1_epi16(1);
6254+
#endif
62346255
__m256i acc[nrc_y] = {};
62356256
float dy[nrc_y];
6257+
#ifdef HAVE_FANCY_SIMD
62366258
int32_t sy[nrc_y];
6259+
#endif
62376260
const int8_t * q8y[nrc_y];
62386261
for (int iy = 0; iy < nrc_y; ++iy) {
62396262
auto dptr = (const float *)info.src1_row(iy);
62406263
dy[iy] = dptr[0];
6264+
#ifdef HAVE_FANCY_SIMD
62416265
auto iptr = (const int32_t *)(dptr + 1);
62426266
sy[iy] = -127*iptr[0];
6267+
#endif
62436268
q8y[iy] = (const int8_t *)(dptr + 2);
62446269
}
62456270
const int8_t * q8x[4];
@@ -6256,35 +6281,43 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62566281
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
62576282
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
62586283
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6259-
//qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6260-
//qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6261-
//qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6262-
//qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6284+
#ifdef HAVE_FANCY_SIMD
62636285
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
62646286
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
62656287
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
62666288
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
6289+
#else
6290+
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6291+
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6292+
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6293+
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6294+
#endif
62676295
for (int iy = 0; iy < nrc_y; ++iy) {
62686296
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6269-
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6270-
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6271-
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6272-
//acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6297+
#ifdef HAVE_FANCY_SIMD
62736298
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
62746299
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
62756300
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
62766301
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
6302+
#else
6303+
auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6304+
auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6305+
auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6306+
auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6307+
auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
6308+
auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
6309+
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
6310+
#endif
62776311
}
62786312
}
62796313
auto scales_x = _mm_loadu_ps(dx);
62806314
for (int iy = 0; iy < nrc_y; ++iy) {
62816315
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6316+
#ifdef HAVE_FANCY_SIMD
62826317
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
6318+
#endif
62836319
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
62846320
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6285-
//auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
6286-
//auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
6287-
//info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
62886321
acc[iy] = _mm256_setzero_si256();
62896322
}
62906323
}

0 commit comments

Comments
 (0)