@@ -6175,6 +6175,9 @@ template <int nrc_y>
61756175static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
61766176 GGML_ASSERT(nrc_x%8 == 0);
61776177 GGML_ASSERT(n%32 == 0);
6178+ #ifndef HAVE_FANCY_SIMD
6179+ auto m1 = _mm256_set1_epi16(1);
6180+ #endif
61786181 __m256i qx[4];
61796182 __m256i sx[4];
61806183 __m256i acc[nrc_y] = {};
@@ -6195,7 +6198,12 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
61956198 }
61966199 for (int iy = 0; iy < nrc_y; ++iy) {
61976200 for (int j = 0; j < 4; ++j) {
6201+ #ifdef HAVE_FANCY_SIMD
61986202 acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6203+ #else
6204+ auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 4*i + j), qx[j]));
6205+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6206+ #endif
61996207 }
62006208 }
62016209 }
@@ -6206,15 +6214,25 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
62066214 }
62076215 for (int iy = 0; iy < nrc_y; ++iy) {
62086216 for (int j = 0; j < 2; ++j) {
6217+ #ifdef HAVE_FANCY_SIMD
62096218 acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6219+ #else
6220+ auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
6221+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6222+ #endif
62106223 }
62116224 }
62126225 }
62136226 if (int i = 2*(n/64); i < n/32) {
62146227 qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
62156228 sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
62166229 for (int iy = 0; iy < nrc_y; ++iy) {
6230+ #ifdef HAVE_FANCY_SIMD
62176231 acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6232+ #else
6233+ auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
6234+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, dot));
6235+ #endif
62186236 }
62196237 }
62206238 for (int iy = 0; iy < nrc_y; ++iy) {
@@ -6230,16 +6248,23 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62306248 GGML_ASSERT(nrc_x%8 == 0);
62316249 GGML_ASSERT(n%32 == 0);
62326250 __m256i qx[4];
6233- //__m256i sx[4];
6251+ #ifndef HAVE_FANCY_SIMD
6252+ __m256i sx[4];
6253+ auto m1 = _mm256_set1_epi16(1);
6254+ #endif
62346255 __m256i acc[nrc_y] = {};
62356256 float dy[nrc_y];
6257+ #ifdef HAVE_FANCY_SIMD
62366258 int32_t sy[nrc_y];
6259+ #endif
62376260 const int8_t * q8y[nrc_y];
62386261 for (int iy = 0; iy < nrc_y; ++iy) {
62396262 auto dptr = (const float *)info.src1_row(iy);
62406263 dy[iy] = dptr[0];
6264+ #ifdef HAVE_FANCY_SIMD
62416265 auto iptr = (const int32_t *)(dptr + 1);
62426266 sy[iy] = -127*iptr[0];
6267+ #endif
62436268 q8y[iy] = (const int8_t *)(dptr + 2);
62446269 }
62456270 const int8_t * q8x[4];
@@ -6256,35 +6281,43 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
62566281 auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
62576282 auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
62586283 auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
6259- //qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6260- //qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6261- //qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6262- //qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6284+ #ifdef HAVE_FANCY_SIMD
62636285 qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
62646286 qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
62656287 qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
62666288 qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
6289+ #else
6290+ qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
6291+ qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
6292+ qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
6293+ qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
6294+ #endif
62676295 for (int iy = 0; iy < nrc_y; ++iy) {
62686296 auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
6269- //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6270- //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6271- //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6272- //acc[iy] = _mm256_dpbusd_epi32(acc[iy], sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6297+ #ifdef HAVE_FANCY_SIMD
62736298 acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
62746299 acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
62756300 acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
62766301 acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
6302+ #else
6303+ auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
6304+ auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
6305+ auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
6306+ auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
6307+ auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
6308+ auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
6309+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
6310+ #endif
62776311 }
62786312 }
62796313 auto scales_x = _mm_loadu_ps(dx);
62806314 for (int iy = 0; iy < nrc_y; ++iy) {
62816315 auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
6316+ #ifdef HAVE_FANCY_SIMD
62826317 sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
6318+ #endif
62836319 auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
62846320 info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
6285- //auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+0]));
6286- //auto minus = _mm_mul_ps(scales_x, _mm_set1_ps(dy[2*iy+1]));
6287- //info.store(ix, iy, _mm_fmadd_ps(scale, _mm_cvtepi32_ps(sumi), minus));
62886321 acc[iy] = _mm256_setzero_si256();
62896322 }
62906323 }
0 commit comments