Skip to content

Commit 7979f85

Browse files
author
Iwan Kawrakow
committed
q8_KV: Better AVX2 gemm
1 parent 5388764 commit 7979f85

File tree

2 files changed: +74 / -75 lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6172,73 +6172,22 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
61726172
}
61736173

61746174
template <int nrc_y>
6175-
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
6175+
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
61766176
GGML_ASSERT(nrc_x%8 == 0);
6177-
GGML_ASSERT(n%128 == 0);
6177+
GGML_ASSERT(n%32 == 0);
61786178
__m256i qx[4];
61796179
__m256i sx[4];
61806180
__m256i acc[nrc_y] = {};
6181-
float dy[2*nrc_y];
6181+
float dy[nrc_y];
61826182
const int8_t * q8y[nrc_y];
61836183
for (int iy = 0; iy < nrc_y; ++iy) {
61846184
auto dptr = (const float *)info.src1_row(iy);
6185-
dy[2*iy+0] = dptr[0];
6186-
dy[2*iy+1] = 127*dptr[1];
6185+
dy[iy] = dptr[0];
61876186
q8y[iy] = (const int8_t *)(dptr + 2);
61886187
}
61896188
for (int ix = 0; ix < nrc_x; ++ix) {
61906189
auto dx = (const float *)((const char *)vx + ix*bx);
61916190
auto q8x = (const int8_t *)(dx + 2);
6192-
//for (int i = 0; i < n/32; ++i) {
6193-
// //auto qx = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
6194-
// auto qx = _mm256_loadu_si256((const __m256i *)q8x + i);
6195-
// auto sx = _mm256_sign_epi8(qx, qx);
6196-
// for (int iy = 0; iy < nrc_y; ++iy) {
6197-
// //acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx, _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
6198-
// acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx));
6199-
// }
6200-
//}
6201-
////for (int iy = 0; iy < nrc_y; ++iy) {
6202-
//// int sumi = 0;
6203-
//// for (int j = 0; j < n; ++j) sumi += q8x[j]*q8y[iy][j];
6204-
//// info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6205-
////}
6206-
//for (int i = 0; i < n/128; ++i) {
6207-
// for (int j = 0; j < 4; ++j) {
6208-
// qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
6209-
// qx[j] = _mm256_add_epi8(qx[j], _mm256_set1_epi8(127));
6210-
// }
6211-
// for (int iy = 0; iy < nrc_y; ++iy) {
6212-
// for (int j = 0; j < 4; ++j) {
6213-
// acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j));
6214-
// }
6215-
// }
6216-
//}
6217-
////for (int i = 2*(n/128); i < n/64; ++i) {
6218-
//// for (int j = 0; j < 2; ++j) {
6219-
//// qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6220-
//// qx[j] = _mm256_add_epi8(qx[j], _mm256_set1_epi8(127));
6221-
//// }
6222-
//// for (int iy = 0; iy < nrc_y; ++iy) {
6223-
//// for (int j = 0; j < 2; ++j) {
6224-
//// acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
6225-
//// }
6226-
//// }
6227-
////}
6228-
////if (int i = 2*(n/64); i < n/32) {
6229-
//// qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6230-
//// qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
6231-
//// for (int iy = 0; iy < nrc_y; ++iy) {
6232-
//// acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
6233-
//// }
6234-
////}
6235-
//// sum [dx * (qx_i - 128) * dy * qy_i] = dx*(dy*sum[qx_i * qy_i] - 128*dy*[sum qy_i]
6236-
//for (int iy = 0; iy < nrc_y; ++iy) {
6237-
// auto sumi = hsum_i32_8(acc[iy]);
6238-
// //info.store(ix, iy, dx[0]*(dy[2*iy+0]*sumi - dy[2*iy+1]));
6239-
// info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6240-
// acc[iy] = _mm256_setzero_si256();
6241-
//}
62426191
for (int i = 0; i < n/128; ++i) {
62436192
for (int j = 0; j < 4; ++j) {
62446193
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
@@ -6250,24 +6199,24 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62506199
}
62516200
}
62526201
}
6253-
//for (int i = 2*(n/128); i < n/64; ++i) {
6254-
// for (int j = 0; j < 2; ++j) {
6255-
// qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6256-
// sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6257-
// }
6258-
// for (int iy = 0; iy < nrc_y; ++iy) {
6259-
// for (int j = 0; j < 2; ++j) {
6260-
// acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6261-
// }
6262-
// }
6263-
//}
6264-
//if (int i = 2*(n/64); i < n/32) {
6265-
// qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6266-
// sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6267-
// for (int iy = 0; iy < nrc_y; ++iy) {
6268-
// acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6269-
// }
6270-
//}
6202+
for (int i = 2*(n/128); i < n/64; ++i) {
6203+
for (int j = 0; j < 2; ++j) {
6204+
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6205+
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6206+
}
6207+
for (int iy = 0; iy < nrc_y; ++iy) {
6208+
for (int j = 0; j < 2; ++j) {
6209+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6210+
}
6211+
}
6212+
}
6213+
if (int i = 2*(n/64); i < n/32) {
6214+
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6215+
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6216+
for (int iy = 0; iy < nrc_y; ++iy) {
6217+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6218+
}
6219+
}
62716220
for (int iy = 0; iy < nrc_y; ++iy) {
62726221
auto sumi = hsum_i32_8(acc[iy]);
62736222
info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
@@ -6276,6 +6225,56 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62766225
}
62776226
}
62786227

6228+
template <int nrc_y>
6229+
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
6230+
GGML_ASSERT(nrc_x%8 == 0);
6231+
GGML_ASSERT(n%32 == 0);
6232+
__m256i qx[4];
6233+
__m256i sx[4];
6234+
__m256i acc[nrc_y] = {};
6235+
float dy[nrc_y];
6236+
const int8_t * q8y[nrc_y];
6237+
for (int iy = 0; iy < nrc_y; ++iy) {
6238+
auto dptr = (const float *)info.src1_row(iy);
6239+
dy[iy] = dptr[0];
6240+
q8y[iy] = (const int8_t *)(dptr + 2);
6241+
}
6242+
const int8_t * q8x[4];
6243+
float dx[4];
6244+
for (int ix = 0; ix < nrc_x; ix += 4) {
6245+
for (int kx = 0; kx < 4; ++kx) {
6246+
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
6247+
dx[kx] = dptr[0];
6248+
q8x[kx] = (const int8_t *)(dptr + 2);
6249+
}
6250+
for (int i = 0; i < n/32; ++i) {
6251+
for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
6252+
auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
6253+
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
6254+
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
6255+
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6256+
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6257+
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6258+
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6259+
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6260+
for (int iy = 0; iy < nrc_y; ++iy) {
6261+
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6262+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6263+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6264+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6265+
acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6266+
}
6267+
}
6268+
auto scales_x = _mm_loadu_ps(dx);
6269+
for (int iy = 0; iy < nrc_y; ++iy) {
6270+
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6271+
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
6272+
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6273+
acc[iy] = _mm256_setzero_si256();
6274+
}
6275+
}
6276+
}
6277+
62796278
#ifdef __AVX512BF16__
62806279
template <int nrc_y>
62816280
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -9223,7 +9222,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
92239222
break;
92249223
case GGML_TYPE_Q8_KV:
92259224
assert (ne00 % 32 == 0);
9226-
mm.funcs[0] = mul_mat_q8_KV_q8_KV<1>;
9225+
mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>;
92279226
mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
92289227
mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
92299228
mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;

ggml/src/iqk/iqk_quantize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2969,7 +2969,7 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
29692969
}
29702970
// TODO: merge this with the above template
29712971
void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
2972-
assert(k % kBlockSize == 0);
2972+
assert(k % 32 == 0);
29732973
auto dptr = (float *)vy;
29742974
auto q8 = (int8_t *)(dptr + 2);
29752975
#ifdef __AVX2__

0 commit comments

Comments
 (0)