@@ -6245,47 +6245,41 @@ template <int nrc_y>
62456245static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
62466246 GGML_ASSERT(nrc_x%8 == 0);
62476247 GGML_ASSERT(n%32 == 0);
6248- #ifndef HAVE_FANCY_SIMD
6248+ __m256i qx[2];
6249+ __m256i acc[nrc_y] = {};
6250+ float dy[nrc_y];
6251+ #ifdef HAVE_FANCY_SIMD
6252+ int32_t sy[nrc_y];
6253+ #else
6254+ __m256i sx[2];
62496255 auto m1 = _mm256_set1_epi16(1);
62506256#endif
6251- __m256i qx[4];
6252- __m256i sx[4];
6253- __m256i acc[nrc_y] = {};
6254- float dy[nrc_y];
62556257 const int8_t * q8y[nrc_y];
62566258 for (int iy = 0; iy < nrc_y; ++iy) {
62576259 auto dptr = (const float *)info.src1_row(iy);
62586260 dy[iy] = dptr[0];
6261+ #ifdef HAVE_FANCY_SIMD
6262+ auto iptr = (const int32_t *)(dptr+1);
6263+ sy[iy] = -127*iptr[0];
6264+ #endif
62596265 q8y[iy] = (const int8_t *)(dptr + 2);
62606266 }
62616267 for (int ix = 0; ix < nrc_x; ++ix) {
62626268 auto dx = (const float *)((const char *)vx + ix*bx);
62636269 auto q8x = (const int8_t *)(dx + 2);
6264- for (int i = 0; i < n/128; ++i) {
6265- for (int j = 0; j < 4; ++j) {
6266- qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 4*i + j);
6267- sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6268- }
6269- for (int iy = 0; iy < nrc_y; ++iy) {
6270- for (int j = 0; j < 4; ++j) {
6270+ for (int i = 0; i < n/64; ++i) {
6271+ for (int j = 0; j < 2; ++j) {
62716272#ifdef HAVE_FANCY_SIMD
6272- acc[iy ] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8( _mm256_loadu_si256((const __m256i *)q8y[iy] + 4 *i + j), qx[j] ));
6273+ qx[j ] = _mm256_add_epi8( _mm256_loadu_si256((const __m256i *)q8x + 2 *i + j), _mm256_set1_epi8(127 ));
62736274#else
6274- auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6275- acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6276- #endif
6277- }
6278- }
6279- }
6280- for (int i = 2*(n/128); i < n/64; ++i) {
6281- for (int j = 0; j < 2; ++j) {
62826275 qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
62836276 sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
6277+ #endif
62846278 }
62856279 for (int iy = 0; iy < nrc_y; ++iy) {
62866280 for (int j = 0; j < 2; ++j) {
62876281#ifdef HAVE_FANCY_SIMD
6288- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx [j], _mm256_sign_epi8( _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j] ));
6282+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx [j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
62896283#else
62906284 auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
62916285 acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
@@ -6294,11 +6288,15 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
62946288 }
62956289 }
62966290 if (int i = 2*(n/64); i < n/32) {
6291+ #ifdef HAVE_FANCY_SIMD
6292+ qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
6293+ #else
62976294 qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
62986295 sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6296+ #endif
62996297 for (int iy = 0; iy < nrc_y; ++iy) {
63006298#ifdef HAVE_FANCY_SIMD
6301- acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx [0], _mm256_sign_epi8( _mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0] ));
6299+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx [0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
63026300#else
63036301 auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
63046302 acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
@@ -6307,7 +6305,11 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
63076305 }
63086306 for (int iy = 0; iy < nrc_y; ++iy) {
63096307 auto sumi = hsum_i32_8(acc[iy]);
6310- info.store(ix, iy, dx[0]*dy[2*iy+0]*sumi);
6308+ #ifdef HAVE_FANCY_SIMD
6309+ info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
6310+ #else
6311+ info.store(ix, iy, dx[0]*dy[iy]*sumi);
6312+ #endif
63116313 acc[iy] = _mm256_setzero_si256();
63126314 }
63136315 }
0 commit comments