@@ -6230,13 +6230,16 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62306230 GGML_ASSERT(nrc_x%8 == 0);
62316231 GGML_ASSERT(n%32 == 0);
62326232 __m256i qx[4];
6233- __m256i sx[4];
6233+ // __m256i sx[4];
62346234 __m256i acc[nrc_y] = {};
62356235 float dy[nrc_y];
6236+ int32_t sy[nrc_y];
62366237 const int8_t * q8y[nrc_y];
62376238 for (int iy = 0; iy < nrc_y; ++iy) {
62386239 auto dptr = (const float *)info.src1_row(iy);
62396240 dy[iy] = dptr[0];
6241+ auto iptr = (const int32_t *)(dptr + 1);
6242+ sy[iy] = -127*iptr[0];
62406243 q8y[iy] = (const int8_t *)(dptr + 2);
62416244 }
62426245 const int8_t * q8x[4];
@@ -6253,23 +6256,35 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62536256 auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
62546257 auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
62556258 auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6256- qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6257- qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6258- qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6259- qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6259+ //qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6260+ //qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6261+ //qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6262+ //qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6263+ qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
6264+ qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
6265+ qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
6266+ qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
62606267 for (int iy = 0; iy < nrc_y; ++iy) {
62616268 auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6262- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6263- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6264- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6265- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6269+ //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6270+ //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6271+ //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6272+ //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6273+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
6274+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
6275+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
6276+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
62666277 }
62676278 }
62686279 auto scales_x = _mm_loadu_ps(dx);
62696280 for (int iy = 0; iy < nrc_y; ++iy) {
62706281 auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6282+ sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
62716283 auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
62726284 info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6285+ //auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
6286+ //auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
6287+ //info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
62736288 acc[iy] = _mm256_setzero_si256();
62746289 }
62756290 }
0 commit comments