Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions ggml/src/iqk/iqk_gemm_kquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX

// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static void mul_mat_q8_k_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
Expand Down Expand Up @@ -1858,6 +1858,67 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
}
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
#ifdef HAVE_FANCY_SIMD
if constexpr (nrc_y < 2) {
mul_mat_q8_k_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
return;
}
GGML_ASSERT(nrc_x%16 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
__m512 acc[nrc_y] = {};
__m512i isum[nrc_y] = {};
__m512i qx[4];
auto m127 = _mm512_set1_epi8(127);
for (int ix = 0; ix < nrc_x; ix += 16) {
const block_q8_k_r8 * iq8l = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
const block_q8_k_r8 * iq8h = (const block_q8_k_r8 *)((const char *)vx + (ix+8)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4ph = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)iq8h[ibl].d), _mm_loadu_si128((const __m128i *)iq8l[ibl].d));
auto d4 = _mm512_cvtph_ps(d4ph);
for (int ib = 0; ib < QK_K/16; ++ib) {
qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+0)),
_mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+0), 1);;
qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+1)),
_mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+1), 1);;
qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+2)),
_mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+2), 1);;
qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq8l[ibl].qs+4*ib+3)),
_mm256_loadu_si256((const __m256i *)iq8h[ibl].qs+4*ib+3), 1);;
qx[0] = _mm512_add_epi8(qx[0], m127);
qx[1] = _mm512_add_epi8(qx[1], m127);
qx[2] = _mm512_add_epi8(qx[2], m127);
qx[3] = _mm512_add_epi8(qx[3], m127);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
auto y256 = MM256_SET_M128I(y128, y128);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
isum[iy] = _mm512_dpbusd_epi32(isum[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
}
auto m4 = _mm512_mul_ps(d4, _mm512_set1_ps(-127.f));
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm512_fmadd_ps(d4y, _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
acc[iy] = _mm512_fmadd_ps(m4, _mm512_set1_ps(q8.y[iy][ibl].sum), acc[iy]);
isum[iy] = _mm512_setzero_si512();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, _mm512_castps512_ps256(acc[iy]));
info.store(ix+8, iy, _mm512_extractf32x8_ps(acc[iy], 1));
acc[iy] = _mm512_setzero_ps();
}
}
#else
mul_mat_q8_k_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
#endif
}

template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
Expand Down Expand Up @@ -2671,9 +2732,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
break;
case GGML_TYPE_Q8_K_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_q8_k_r8_q8_k<16>;
#endif
//#ifdef HAVE_FANCY_SIMD
// func16 = mul_mat_q8_k_r8_q8_k<16>;
//#endif
break;
case GGML_TYPE_Q8_KV:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ struct MulMat {
case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q8_1: return 8;
case GGML_TYPE_Q8_K_R8:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_BF16_R16: return 16;
Expand Down