@@ -6172,73 +6172,22 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
61726172}
61736173
61746174template <int nrc_y>
6175- static void mul_mat_q8_KV_q8_KV (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
6175+ static void mul_mat_q8_KV_q8_KV_1 (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
61766176 GGML_ASSERT(nrc_x%8 == 0);
6177- GGML_ASSERT(n%128 == 0);
6177+ GGML_ASSERT(n%32 == 0);
61786178 __m256i qx[4];
61796179 __m256i sx[4];
61806180 __m256i acc[nrc_y] = {};
6181- float dy[2* nrc_y];
6181+ float dy[nrc_y];
61826182 const int8_t * q8y[nrc_y];
61836183 for (int iy = 0; iy < nrc_y; ++iy) {
61846184 auto dptr = (const float *)info.src1_row(iy);
6185- dy[2*iy+0] = dptr[0];
6186- dy[2*iy+1] = 127*dptr[1];
6185+ dy[iy] = dptr[0];
61876186 q8y[iy] = (const int8_t *)(dptr + 2);
61886187 }
61896188 for (int ix = 0; ix < nrc_x; ++ix) {
61906189 auto dx = (const float *)((const char *)vx + ix*bx);
61916190 auto q8x = (const int8_t *)(dx + 2);
6192- //for (int i = 0; i < n/32; ++i) {
6193- // //auto qx = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
6194- // auto qx = _mm256_loadu_si256((const __m256i *)q8x + i);
6195- // auto sx = _mm256_sign_epi8(qx, qx);
6196- // for (int iy = 0; iy < nrc_y; ++iy) {
6197- // //acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx, _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
6198- // acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx));
6199- // }
6200- //}
6201- ////for (int iy = 0; iy < nrc_y; ++iy) {
6202- //// int sumi = 0;
6203- //// for (int j = 0; j < n; ++j) sumi += q8x[j]*q8y[iy][j];
6204- //// info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6205- ////}
6206- //for (int i = 0; i < n/128; ++i) {
6207- // for (int j = 0; j < 4; ++j) {
6208- // qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
6209- // qx[j] = _mm256_add_epi8(qx[j], _mm256_set1_epi8(127));
6210- // }
6211- // for (int iy = 0; iy < nrc_y; ++iy) {
6212- // for (int j = 0; j < 4; ++j) {
6213- // acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j));
6214- // }
6215- // }
6216- //}
6217- ////for (int i = 2*(n/128); i < n/64; ++i) {
6218- //// for (int j = 0; j < 2; ++j) {
6219- //// qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6220- //// qx[j] = _mm256_add_epi8(qx[j], _mm256_set1_epi8(127));
6221- //// }
6222- //// for (int iy = 0; iy < nrc_y; ++iy) {
6223- //// for (int j = 0; j < 2; ++j) {
6224- //// acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
6225- //// }
6226- //// }
6227- ////}
6228- ////if (int i = 2*(n/64); i < n/32) {
6229- //// qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6230- //// qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
6231- //// for (int iy = 0; iy < nrc_y; ++iy) {
6232- //// acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
6233- //// }
6234- ////}
6235- //// sum [dx * (qx_i - 128) * dy * qy_i] = dx*(dy*sum[qx_i * qy_i] - 128*dy*[sum qy_i]
6236- //for (int iy = 0; iy < nrc_y; ++iy) {
6237- // auto sumi = hsum_i32_8(acc[iy]);
6238- // //info.store(ix, iy, dx[0]*(dy[2*iy+0]*sumi - dy[2*iy+1]));
6239- // info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6240- // acc[iy] = _mm256_setzero_si256();
6241- //}
62426191 for (int i = 0; i < n/128; ++i) {
62436192 for (int j = 0; j < 4; ++j) {
62446193 qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
@@ -6250,24 +6199,24 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62506199 }
62516200 }
62526201 }
6253- // for (int i = 2*(n/128); i < n/64; ++i) {
6254- // for (int j = 0; j < 2; ++j) {
6255- // qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6256- // sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6257- // }
6258- // for (int iy = 0; iy < nrc_y; ++iy) {
6259- // for (int j = 0; j < 2; ++j) {
6260- // acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6261- // }
6262- // }
6263- // }
6264- // if (int i = 2*(n/64); i < n/32) {
6265- // qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6266- // sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6267- // for (int iy = 0; iy < nrc_y; ++iy) {
6268- // acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6269- // }
6270- // }
6202+ for (int i = 2*(n/128); i < n/64; ++i) {
6203+ for (int j = 0; j < 2; ++j) {
6204+ qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
6205+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6206+ }
6207+ for (int iy = 0; iy < nrc_y; ++iy) {
6208+ for (int j = 0; j < 2; ++j) {
6209+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6210+ }
6211+ }
6212+ }
6213+ if (int i = 2*(n/64); i < n/32) {
6214+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
6215+ sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6216+ for (int iy = 0; iy < nrc_y; ++iy) {
6217+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6218+ }
6219+ }
62716220 for (int iy = 0; iy < nrc_y; ++iy) {
62726221 auto sumi = hsum_i32_8(acc[iy]);
62736222 info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
@@ -6276,6 +6225,56 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62766225 }
62776226}
62786227
6228+ template <int nrc_y>
6229+ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
6230+ GGML_ASSERT(nrc_x%8 == 0);
6231+ GGML_ASSERT(n%32 == 0);
6232+ __m256i qx[4];
6233+ __m256i sx[4];
6234+ __m256i acc[nrc_y] = {};
6235+ float dy[nrc_y];
6236+ const int8_t * q8y[nrc_y];
6237+ for (int iy = 0; iy < nrc_y; ++iy) {
6238+ auto dptr = (const float *)info.src1_row(iy);
6239+ dy[iy] = dptr[0];
6240+ q8y[iy] = (const int8_t *)(dptr + 2);
6241+ }
6242+ const int8_t * q8x[4];
6243+ float dx[4];
6244+ for (int ix = 0; ix < nrc_x; ix += 4) {
6245+ for (int kx = 0; kx < 4; ++kx) {
6246+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
6247+ dx[kx] = dptr[0];
6248+ q8x[kx] = (const int8_t *)(dptr + 2);
6249+ }
6250+ for (int i = 0; i < n/32; ++i) {
6251+ for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
6252+ auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
6253+ auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
6254+ auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
6255+ auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6256+ qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6257+ qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6258+ qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6259+ qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6260+ for (int iy = 0; iy < nrc_y; ++iy) {
6261+ auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6262+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6263+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6264+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6265+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6266+ }
6267+ }
6268+ auto scales_x = _mm_loadu_ps(dx);
6269+ for (int iy = 0; iy < nrc_y; ++iy) {
6270+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6271+ auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
6272+ info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6273+ acc[iy] = _mm256_setzero_si256();
6274+ }
6275+ }
6276+ }
6277+
62796278#ifdef __AVX512BF16__
62806279template <int nrc_y>
62816280static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -9223,7 +9222,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
92239222 break;
92249223 case GGML_TYPE_Q8_KV:
92259224 assert (ne00 % 32 == 0);
9226- mm.funcs[0] = mul_mat_q8_KV_q8_KV <1>;
9225+ mm.funcs[0] = mul_mat_q8_KV_q8_KV_1 <1>;
92279226 mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
92289227 mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
92299228 mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;
0 commit comments