From 62d8dd932b5a7724a6cb78fa80d5aeccc8fe8779 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 1 Jun 2025 18:07:10 +0300 Subject: [PATCH 1/4] If available, use bf16 for iq2_kt gemm/gemv With that, we get PP-512 = 234 t/s, so prompt processing is now in the low range of row-interleaved quants. --- ggml/src/ggml.c | 2 + ggml/src/iqk/iqk_gemm_ktquants.cpp | 130 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 6 +- 3 files changed, 133 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5bb75d322..907abd5ed 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1585,6 +1585,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq2_kt_q8_k, #ifdef __ARM_NEON .vec_dot_type = GGML_TYPE_F16, +#elif defined __AVX512BF16__ + .vec_dot_type = GGML_TYPE_BF16, #else .vec_dot_type = GGML_TYPE_F32, #endif diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index bc7bcf8b6..5ace1ce45 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -51,7 +51,28 @@ struct Trellis1 { const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7); const __m256i mask1 = _mm256_set1_epi32(kmask); const __m256i mask2 = _mm256_set1_epi32(km32); - +#ifdef __AVX512BF16__ + const __m512i shuf1 = load_shuffle(0); + const __m512i shuf2 = load_shuffle(1); + static __m512i load_shuffle(int i) { + static const int32_t k_idx[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return _mm512_loadu_si512((const __m512i *)k_idx + i); + } + inline __m256 gen8(__m256i i8) const { + auto v = _mm512_cvtph_ps(i8); + v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v)); + return _mm512_castps512_ps256(v); + } + inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const { + auto v1 = _mm512_cvtph_ps(i8_1); + auto v2 = _mm512_cvtph_ps(i8_2); + auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2); + auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2); + auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2)); + return __m256i(_mm512_cvtneps_pbh(v)); + } +#endif inline __m256i next8(uint32_t val) const { auto mval = _mm256_set1_epi32(val); auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); @@ -97,6 +118,93 @@ struct Trellis2 { } }; +#ifdef __AVX512BF16__ +template +void mul_mat_iq2_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + union { __m256 vec; float val[8]; } s_helper; + + constexpr int k_acc = 2 * nrc_y; + __m256 accd[k_acc]; + const ggml_bf16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.05f; + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = 
_mm256_cvtepi32_ps(s32); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + auto xval1 = __m256bh(trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale)); + auto xval2 = __m256bh(trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K + 32*ib + 0))); + auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K + 32*ib + 16))); + accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, xval1); + accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, xval2); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, d*hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1]))); + } + } +} + +void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + union { __m256 vec; float val[8]; } s_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto vd = _mm256_set1_ps(*dptr * 31.75f * 1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32)); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale); + auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale); + _mm256_storeu_si256((__m256i *)(y + i*QK_K + 32*ib + 0), xval1); + _mm256_storeu_si256((__m256i *)(y + i*QK_K + 32*ib + 16), xval2); + } + } + y += stride_y; + } +} +#endif + void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -446,12 +554,22 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) { - return false; - } + if (ne00%QK_K != 0) return false; func16 = nullptr; +#ifdef __AVX512BF16__ + if (typeA == GGML_TYPE_IQ2_KT) { + if (typeB != GGML_TYPE_BF16) return false; + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_BF16_T, kernels); + return true; + } +#endif + + if (ggml_type(typeB) != GGML_TYPE_F32) { + return false; + } + switch (typeA) { case GGML_TYPE_IQ2_KT: IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F32_T, kernels); @@ -472,7 +590,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_BF16 : type; +#else case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; +#endif case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? 
GGML_TYPE_F32 : type; default: break; @@ -606,7 +610,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: - return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; + return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: From 9890618db4180cc613381547a34258010b1c7e3a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 2 Jun 2025 07:14:54 +0300 Subject: [PATCH 2/4] If available, use bf16 for iq3_kt gemm/gemv With that, we get PP-512 = 233 t/s. --- ggml/src/ggml.c | 2 + ggml/src/iqk/iqk_gemm_ktquants.cpp | 123 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 3 +- 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 907abd5ed..9e6c70b1e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1604,6 +1604,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq3_kt_q8_k, #ifdef __ARM_NEON .vec_dot_type = GGML_TYPE_F16, +#elif defined __AVX512BF16__ + .vec_dot_type = GGML_TYPE_BF16, #else .vec_dot_type = GGML_TYPE_F32, #endif diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 5ace1ce45..37265003f 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -64,12 +64,17 @@ struct Trellis1 { v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v)); return _mm512_castps512_ps256(v); } + //template inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const { auto v1 = _mm512_cvtph_ps(i8_1); auto v2 = _mm512_cvtph_ps(i8_2); auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2); auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2); auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2)); + //if constexpr (is_abs) { + // v = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), v); + //} + //v = _mm512_mul_ps(scale, v); return __m256i(_mm512_cvtneps_pbh(v)); } #endif @@ -423,6 +428,116 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +#ifdef __AVX512BF16__ +void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + union { __m256 vec; float val[8]; } s_helper; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + __m256i all_signs[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + if (dptr[0] < 0) { + printf("Oops: row scale is %g\n", dptr[0]); + GGML_ABORT("Fatal error"); + } + auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f); + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32)); + for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j))); + auto mask = _mm256_set1_epi16(0x01); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), 
scale); + auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale); + auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000)); + auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000)); + auto x1 = _mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff))); + auto x2 = _mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff))); + _mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+ 0), x1); + _mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+16), x2); + mask = _mm256_slli_epi16(mask, 1); + } + } + y += stride_y; + } +} + +template +void mul_mat_iq3_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + union { __m256 vec; float val[8]; } s_helper; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + __m256i all_signs[2]; + + __m256 accd[2*nrc_y]; + const ggml_bf16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + if (dptr[0] < 0) { + printf("Oops: row scale is %g\n", dptr[0]); + GGML_ABORT("Fatal error"); + } + auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f); + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32)); + for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j))); + auto mask = _mm256_set1_epi16(0x01); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale); + auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale); + auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000)); + auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000)); + mask = _mm256_slli_epi16(mask, 1); + auto x1 = __m256bh(_mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff)))); + auto x2 = __m256bh(_mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+ 0))); + auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+16))); + accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, x1); + accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, x2); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1]))); + } + } +} +#endif + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -564,6 +679,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, 
std::array= 32 ? GGML_TYPE_BF16 : type; + case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type; #else case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; -#endif case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; +#endif case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; default: break; } From 0715919fc031751a58bf8bcb114d4a0dbc5decef Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 2 Jun 2025 11:17:18 +0300 Subject: [PATCH 3/4] BF16 for iq4_kt --- ggml/src/ggml.c | 2 + ggml/src/iqk/iqk_gemm_ktquants.cpp | 185 +++++++++++++++++++++++++++-- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 3 files changed, 180 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9e6c70b1e..d9262bed6 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1623,6 +1623,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq4_kt_q8_k, #ifdef __ARM_NEON .vec_dot_type = GGML_TYPE_F16, +#elif defined __AVX512BF16__ + .vec_dot_type = GGML_TYPE_BF16, #else .vec_dot_type = GGML_TYPE_F32, #endif diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 37265003f..883ce3005 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -28,6 +28,7 @@ inline float trellis_gen(uint32_t& val, uint32_t* s) { return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]); } +template struct Trellis1 { constexpr static uint32_t kmask = 0x8fff8fff; constexpr static uint32_t km32 = 0x3b603b60; @@ -47,8 +48,8 @@ struct Trellis1 { constexpr static uint32_t kb6 = kb5*ka+kb; constexpr static uint32_t ka7 = ka6*ka; constexpr static uint32_t kb7 = kb6*ka+kb; - const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7); - const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7); + const __m256i mka = is4 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7); + const __m256i mkb = is4 ? 
_mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7); const __m256i mask1 = _mm256_set1_epi32(kmask); const __m256i mask2 = _mm256_set1_epi32(km32); #ifdef __AVX512BF16__ @@ -64,25 +65,31 @@ struct Trellis1 { v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v)); return _mm512_castps512_ps256(v); } - //template inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const { auto v1 = _mm512_cvtph_ps(i8_1); auto v2 = _mm512_cvtph_ps(i8_2); auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2); auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2); auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2)); - //if constexpr (is_abs) { - // v = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), v); - //} - //v = _mm512_mul_ps(scale, v); return __m256i(_mm512_cvtneps_pbh(v)); } + inline __m256i gen8bh(uint32_t val1, uint32_t val2, __m512 scale) const { + return gen8bh(next8(val1), next8(val2), scale); + } + inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale) const { + return gen8bh(next8(val1, val2), next8(val3, val4), scale); + } #endif inline __m256i next8(uint32_t val) const { auto mval = _mm256_set1_epi32(val); auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); return _mm256_xor_si256(_mm256_and_si256(mres, mask1), mask2); } + inline __m256i next8(uint32_t val1, uint32_t val2) const { + __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); + __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_xor_si256(_mm256_and_si256(mres, _mm256_set1_epi32(kmask)), _mm256_set1_epi32(km32)); + } }; inline __m256 trellis_gen8(__m256i i8) { @@ -593,7 +600,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis1 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; @@ -665,6 +672,163 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +#ifdef __AVX512BF16__ +template +void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); +/* + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis1 trellis; + + union { __m256 vec; float val[8]; } s_helper; + union { __m256i vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = 2 * nrc_y; + + __m256 accd[k_acc]; + const ggml_bf16_t * y[nrc_y]; + float row_sum[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + //auto sum = _mm256_setzero_ps(); + //for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i)); + row_sum[iy] = 0; //hsum_float_8(sum); + } + + uint32_t val[8]; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto dav = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + s_helper.vec = _mm256_mul_ps(d, 
_mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = _mm512_set1_ps(s_helper.val[ib+0]); + auto scale2 = _mm512_set1_ps(s_helper.val[ib+4]); + for (int j = 0; j < 4; j += 2) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + val[0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0]; + val[1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0]; + val[2] = ql[8*ib+2*j+ 2] + ((qh[8*ib+2*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0]; + val[3] = ql[8*ib+2*j+ 3] + ((qh[8*ib+2*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0]; + auto xval1 = trellis.gen8bh(val[0], val[1], val[2], val[3], scale1); + val[4] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4]; + val[5] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4]; + val[6] = ql[8*ib+2*j+34] + ((qh[8*ib+2*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4]; + val[7] = ql[8*ib+2*j+35] + ((qh[8*ib+2*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4]; + auto xval2 = trellis.gen8bh(val[4], val[5], val[6], val[7], scale2); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+ 0)); + auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+128)); + accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1)); + accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2)); + } + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]); + } + } +*/ + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis1 trellis; + + union { __m256 vec; float val[8]; } s_helper; + union { __m256i vec; uint32_t val[8]; } o_helper; + + __m256 accd[2*nrc_y]; + const ggml_bf16_t * y[nrc_y]; + float row_sum[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + //row_sum[iy] = 0; + auto sum = _mm256_setzero_ps(); + auto one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f)); + for (int i = 0; i < n/16; ++i) sum = _mm256_dpbf16_ps(sum, one, __m256bh(_mm256_loadu_si256((const __m256i *)y[iy] + i))); + row_sum[iy] = hsum_float_8(sum); + } + + union { __m512i vec[2]; uint16_t val[64]; } h_helper; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto dav = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0)); + auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1)); + //const uint32_t * qh = (const uint32_t *)(ql + kNumGroups); + auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups))); + h_helper.vec[0] = _mm512_add_epi16(vql1, 
_mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00))); + h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00))); + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + + // qh[(Q::kNg*ib + j)%(kNumGroups/2)] -> qh[(8*ib+j)%32], j = 0...7 + // ib = 0 -> 0....7 (uint8_t) -> 0, 1 (uint32_t), shift = 0 + // ib = 1 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 0 + // ib = 2 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 0 + // ib = 3 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 0 + // ib = 4 -> 0....7 (uint8_t) -> 1, 1 (uint32_t), shift = 4 + // ib = 5 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 4 + // ib = 6 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 4 + // ib = 7 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 4 + uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib]; + uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib]; + uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib]; + uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib]; + auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale); + + val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib]; + val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib]; + val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib]; + val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib]; + auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale); + + for (int iy = 0; iy < nrc_y; ++iy) { + auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+ 0)); + auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+16)); + accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1)); + accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2)); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]); + } + } +} +#endif + } // namespace bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { @@ -684,6 +848,11 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; #endif - case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; + //case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; default: break; } #else From 061d064b21c5c55266f00cb03f8583b2f187ee60 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 2 Jun 2025 11:59:20 +0300 Subject: [PATCH 4/4] If available, use bf16 for iq4_kt gemm/gemv With that, we get PP-512 = 240 t/s. 
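For reference, the iq4_kt GEMM below leans on two things: (1) an iq4_kt row dequantizes to scale*trellis_value plus a per-row constant (dptr[1]), so the dot product with a bf16 activation row y can be formed as dot(scale*trellis, y) + dptr[1]*sum(y), with sum(y) obtained once per row by a bf16 dot product against 1.0; (2) every product is accumulated with _mm256_dpbf16_ps. A minimal, self-contained sketch of that accumulation pattern (not part of the patch; the function name is illustrative, and bf16 values are passed as the raw uint16_t bits that ggml_bf16_t stores):

    // Sketch only: dot(x, y) + offset*sum(y) over n bf16 values (n % 16 == 0).
    // Compile with -mavx512bf16 -mavx512vl (GCC/Clang).
    #include <immintrin.h>
    #include <cstdint>

    static inline float bf16_dot_plus_offset(const uint16_t * x, const uint16_t * y, int n, float offset) {
        __m256 acc  = _mm256_setzero_ps();                             // accumulates x.y
        __m256 accs = _mm256_setzero_ps();                             // accumulates sum(y)
        const __m256bh one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f));  // 16 x bf16(1.0)
        for (int i = 0; i < n/16; ++i) {                               // 16 bf16 lanes per 256-bit load
            auto xb = __m256bh(_mm256_loadu_si256((const __m256i *)x + i));
            auto yb = __m256bh(_mm256_loadu_si256((const __m256i *)y + i));
            acc  = _mm256_dpbf16_ps(acc,  xb, yb);                     // acc[k]  += x[2k]*y[2k] + x[2k+1]*y[2k+1]
            accs = _mm256_dpbf16_ps(accs, one, yb);                    // accs[k] += y[2k] + y[2k+1]
        }
        float tmp[8], tmps[8], dot = 0.f, sumy = 0.f;
        _mm256_storeu_ps(tmp,  acc);
        _mm256_storeu_ps(tmps, accs);
        for (int k = 0; k < 8; ++k) { dot += tmp[k]; sumy += tmps[k]; }
        return dot + offset*sumy;                                      // == dot over y of (x + offset)
    }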
--- ggml/src/iqk/iqk_gemm_ktquants.cpp | 106 +++++++++++++---------------- ggml/src/iqk/iqk_mul_mat.cpp | 3 +- 2 files changed, 48 insertions(+), 61 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 883ce3005..93edc0abf 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -73,12 +73,23 @@ struct Trellis1 { auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2)); return __m256i(_mm512_cvtneps_pbh(v)); } + inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale, __m512 offset) const { + auto v1 = _mm512_cvtph_ps(i8_1); + auto v2 = _mm512_cvtph_ps(i8_2); + auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2); + auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2); + auto v = _mm512_fmadd_ps(scale, _mm512_add_ps(vs1, vs2), offset); + return __m256i(_mm512_cvtneps_pbh(v)); + } inline __m256i gen8bh(uint32_t val1, uint32_t val2, __m512 scale) const { return gen8bh(next8(val1), next8(val2), scale); } inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale) const { return gen8bh(next8(val1, val2), next8(val3, val4), scale); } + inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale, __m512 offset) const { + return gen8bh(next8(val1, val2), next8(val3, val4), scale, offset); + } #endif inline __m256i next8(uint32_t val) const { auto mval = _mm256_set1_epi32(val); @@ -673,10 +684,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } #ifdef __AVX512BF16__ -template -void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); -/* const int nb = n/QK_K; constexpr int kNumGroups = 64; @@ -685,67 +694,54 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; - constexpr int k_acc = 2 * nrc_y; - - __m256 accd[k_acc]; - const ggml_bf16_t * y[nrc_y]; - float row_sum[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) { - y[iy] = (const ggml_bf16_t *)info.src1_row(iy); - //auto sum = _mm256_setzero_ps(); - //for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i)); - row_sum[iy] = 0; //hsum_float_8(sum); - } - - uint32_t val[8]; + union { __m512i vec[2]; uint16_t val[64]; } h_helper; for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); - auto dav = dptr[1]; + auto dav = _mm512_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); - for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); const uint32_t * shb = x[i].qs; - const uint8_t * ql = (const uint8_t *)(shb + 8); - const uint8_t * qh = ql + kNumGroups; + const uint8_t * ql = (const uint8_t *)(shb + 8); + auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0)); + auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1)); + auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups))); + h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00))); + h_helper.vec[1] = _mm512_add_epi16(vql2, 
_mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00))); auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); - for (int ib = 0; ib < 4; ++ib) { - auto scale1 = _mm512_set1_ps(s_helper.val[ib+0]); - auto scale2 = _mm512_set1_ps(s_helper.val[ib+4]); - for (int j = 0; j < 4; j += 2) { - const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - val[0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0]; - val[1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0]; - val[2] = ql[8*ib+2*j+ 2] + ((qh[8*ib+2*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0]; - val[3] = ql[8*ib+2*j+ 3] + ((qh[8*ib+2*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0]; - auto xval1 = trellis.gen8bh(val[0], val[1], val[2], val[3], scale1); - val[4] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4]; - val[5] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4]; - val[6] = ql[8*ib+2*j+34] + ((qh[8*ib+2*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4]; - val[7] = ql[8*ib+2*j+35] + ((qh[8*ib+2*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4]; - auto xval2 = trellis.gen8bh(val[4], val[5], val[6], val[7], scale2); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+ 0)); - auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+8*j+128)); - accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1)); - accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2)); - } - } + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scale = _mm512_set1_ps(s_helper.val[ib]); + + uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib]; + uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib]; + uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib]; + uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib]; + auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale, dav); + + val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib]; + val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib]; + val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib]; + val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib]; + auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale, dav); + + _mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+ 0), xval1); + _mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+16), xval2); } } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]); - } + y += stride_y; + } -*/ +} + +template +void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; @@ -759,7 +755,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in float row_sum[nrc_y]; for (int iy = 0; iy < 
nrc_y; ++iy) { y[iy] = (const ggml_bf16_t *)info.src1_row(iy); - //row_sum[iy] = 0; auto sum = _mm256_setzero_ps(); auto one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f)); for (int i = 0; i < n/16; ++i) sum = _mm256_dpbf16_ps(sum, one, __m256bh(_mm256_loadu_si256((const __m256i *)y[iy] + i))); @@ -782,7 +777,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in const uint8_t * ql = (const uint8_t *)(shb + 8); auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0)); auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1)); - //const uint32_t * qh = (const uint32_t *)(ql + kNumGroups); auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups))); h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00))); h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00))); @@ -792,15 +786,6 @@ void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& in for (int ib = 0; ib < QK_K/32; ++ib) { auto scale = _mm512_set1_ps(s_helper.val[ib]); - // qh[(Q::kNg*ib + j)%(kNumGroups/2)] -> qh[(8*ib+j)%32], j = 0...7 - // ib = 0 -> 0....7 (uint8_t) -> 0, 1 (uint32_t), shift = 0 - // ib = 1 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 0 - // ib = 2 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 0 - // ib = 3 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 0 - // ib = 4 -> 0....7 (uint8_t) -> 1, 1 (uint32_t), shift = 4 - // ib = 5 -> 8...15 (uint8_t) -> 2, 3 (uint32_t), shift = 4 - // ib = 6 -> 16..23 (uint8_t) -> 4, 5 (uint32_t), shift = 4 - // ib = 7 -> 24..31 (uint8_t) -> 6, 7 (uint32_t), shift = 4 uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib]; uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib]; uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib]; @@ -882,11 +867,12 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * #ifdef __AVX512BF16__ case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break; case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (ggml_bf16_t *)y, stride_y, nrc_x); break; #else case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; -#endif case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; +#endif default: return false; } return true; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f2e9f68ef..d2bf56947 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -239,11 +239,12 @@ struct MulMat { #ifdef __AVX512BF16__ case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type; #else case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; #endif - //case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; default: break; } #else
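
Taken together, the four patches give all three trellis ("KT") types the same BF16 path: Trellis1::next8 produces eight LCG states per group, gen8bh turns each 32-bit state into a float from its two fp16 halves, applies the block scale, and packs the result to bf16, and _mm256_dpbf16_ps accumulates it against bf16 activations; the selection code above only routes IQ2_KT/IQ3_KT/IQ4_KT to that path (and to bf16 dequantization once nrc_y >= 32) when __AVX512BF16__ is available. A scalar reference for what next8/gen8 compute per element may help when reading the SIMD code (illustrative sketch, not part of the patches; it assumes ggml's ggml_fp16_t and GGML_FP16_TO_FP32, as already used by trellis_gen, and takes the ka/kb generator constants defined in Trellis1 as parameters):

    // Sketch only: one trellis value = one LCG step, then mask/xor and the sum of the
    // two fp16 halves of the 32-bit state. next8 emits 8 consecutive values at once
    // using the precomputed ka^i / kb chains; gen8/gen8bh vectorize the fp16 sum.
    #include <cstdint>
    #include <cstring>

    static inline float trellis_next_value(uint32_t & state, uint32_t ka, uint32_t kb) {
        state = state*ka + kb;                              // LCG step
        uint32_t u = (state & 0x8fff8fffu) ^ 0x3b603b60u;   // kmask / km32 from Trellis1
        ggml_fp16_t h[2];
        std::memcpy(h, &u, sizeof u);                       // reinterpret as two fp16 values
        return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
    }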