From c462c5bdf6667473322807c89edb3a6f928480dc Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 14 Jul 2025 18:45:55 +0300
Subject: [PATCH] q8_k_r8: AVX512 version

On my 7950X this is slower than what we have on the main branch
---
 ggml/src/iqk/iqk_gemm_kquants.cpp | 69 +++++++++++++++++++++++++++++--
 ggml/src/iqk/iqk_mul_mat.cpp      |  4 +-
 2 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index c44cf41bc..054e4e0ec 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -1792,7 +1792,7 @@ template void set_functions(std::array
 }
 
 template <int nrc_y>
-static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+static void mul_mat_q8_k_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
     Q8<nrc_y, block_q8_K> q8(info);
 #ifndef HAVE_FANCY_SIMD
@@ -1858,6 +1858,67 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
         }
     }
 }
+template <int nrc_y>
+static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+#ifdef HAVE_FANCY_SIMD
+    if constexpr (nrc_y < 2) {
+        mul_mat_q8_k_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+        return;
+    }
+    GGML_ASSERT(nrc_x%16 == 0);
+    Q8<nrc_y, block_q8_K> q8(info);
+    int nbl = n / QK_K;
+    __m512  acc[nrc_y] = {};
+    __m512i isum[nrc_y] = {};
+    __m512i qx[4];
+    auto m127 = _mm512_set1_epi8(127);
+    for (int ix = 0; ix < nrc_x; ix += 16) {
+        const block_q8_k_r8 * iq8l = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
+        const block_q8_k_r8 * iq8h = (const block_q8_k_r8 *)((const char *)vx + (ix+8)*bx);
+        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+            auto d4ph = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)iq8h[ibl].d), _mm_loadu_si128((const __m128i *)iq8l[ibl].d));
+            auto d4 = _mm512_cvtph_ps(d4ph);
+            for (int ib = 0; ib < QK_K/16; ++ib) {
+                qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+0)),
+                                           _mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+0), 1);;
+                qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+1)),
+                                           _mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+1), 1);;
+                qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+2)),
+                                           _mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+2), 1);;
+                qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+3)),
+                                           _mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+3), 1);;
+                qx[0] = _mm512_add_epi8(qx[0], m127);
+                qx[1] = _mm512_add_epi8(qx[1], m127);
+                qx[2] = _mm512_add_epi8(qx[2], m127);
+                qx[3] = _mm512_add_epi8(qx[3], m127);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
+                    auto y256 = MM256_SET_M128I(y128, y128);
+                    auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+                    isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                }
+            }
+            auto m4 = _mm512_mul_ps(d4, _mm512_set1_ps(-127.f));
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto d4y = _mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl)));
+                acc[iy] = _mm512_fmadd_ps(d4y, _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
+                acc[iy] = _mm512_fmadd_ps(m4, _mm512_set1_ps(q8.y[iy][ibl].sum), acc[iy]);
+                isum[iy] = _mm512_setzero_si512();
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, _mm512_castps512_ps256(acc[iy]));
+            info.store(ix+8, iy, _mm512_extractf32x8_ps(acc[iy], 1));
+            acc[iy] = _mm512_setzero_ps();
+        }
+    }
+#else
+    mul_mat_q8_k_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+#endif
+}
 
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -2671,9 +2732,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array
             break;
         case GGML_TYPE_Q8_K_R8:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
-#ifdef HAVE_FANCY_SIMD
-            func16 = mul_mat_q8_k_r8_q8_k<16>;
-#endif
+//#ifdef HAVE_FANCY_SIMD
+//            func16 = mul_mat_q8_k_r8_q8_k<16>;
+//#endif
             break;
         case GGML_TYPE_Q8_KV:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 0054f6cb3..912818951 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -337,8 +337,8 @@ struct MulMat {
            case GGML_TYPE_Q5_K_R4:
            case GGML_TYPE_Q8_KV:
            case GGML_TYPE_Q8_KV_R8:
-            case GGML_TYPE_Q8_1:
-            case GGML_TYPE_Q8_K_R8: return 8;
+            case GGML_TYPE_Q8_1: return 8;
+            case GGML_TYPE_Q8_K_R8:
            case GGML_TYPE_Q4_0_R8:
            case GGML_TYPE_Q8_0_R8:
            case GGML_TYPE_BF16_R16: return 16;
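
A minimal scalar sketch (not code from this patch) of what one 256-value block contributes to a single output value in the new kernel. _mm512_dpbusd_epi32 multiplies unsigned bytes by signed bytes, which is why the q8_k_r8 quants are biased by +127 (m127) before the dot product and the bias is removed again afterwards. The names dot_q8_block, x, y, d_x, d_y are made up for illustration, and the sketch assumes the row quants lie in [-127, 127] so that x + 127 fits in an unsigned byte.

    #include <cstdint>

    // Scalar model of the per-block arithmetic: the kernel evaluates
    //   d_x * d_y * sum(x_i * y_i)
    // as sum((x_i + 127) * y_i) - 127 * sum(y_i), because the dpbusd instruction
    // requires its first operand to be unsigned bytes.
    // x: 256 row quants (q8_k_r8 side), y: 256 activation quants,
    // d_x, d_y: the corresponding dequantization scales.
    inline float dot_q8_block(const int8_t * x, const int8_t * y, float d_x, float d_y) {
        int32_t isum = 0, ysum = 0;
        for (int i = 0; i < 256; ++i) {
            isum += (int32_t(x[i]) + 127) * int32_t(y[i]); // what dpbusd accumulates
            ysum += y[i];                                  // bias correction term
        }
        return d_x * d_y * float(isum - 127 * ysum);       // undo the +127 bias
    }

In the kernel the correction term is not recomputed: it is folded in through m4 = _mm512_mul_ps(d4, _mm512_set1_ps(-127.f)) and the precomputed per-block activation sum (q8.y[iy][ibl].sum), while the 16 rows handled per iteration are packed with rows ix..ix+7 in the low 256 bits and rows ix+8..ix+15 in the high 256 bits of each 512-bit register.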