diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5bb75d322..d9262bed6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1585,6 +1585,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot = vec_dot_iq2_kt_q8_k,
 #ifdef __ARM_NEON
         .vec_dot_type = GGML_TYPE_F16,
+#elif defined __AVX512BF16__
+        .vec_dot_type = GGML_TYPE_BF16,
 #else
         .vec_dot_type = GGML_TYPE_F32,
 #endif
@@ -1602,6 +1604,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot = vec_dot_iq3_kt_q8_k,
 #ifdef __ARM_NEON
         .vec_dot_type = GGML_TYPE_F16,
+#elif defined __AVX512BF16__
+        .vec_dot_type = GGML_TYPE_BF16,
 #else
         .vec_dot_type = GGML_TYPE_F32,
 #endif
@@ -1619,6 +1623,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot = vec_dot_iq4_kt_q8_k,
 #ifdef __ARM_NEON
         .vec_dot_type = GGML_TYPE_F16,
+#elif defined __AVX512BF16__
+        .vec_dot_type = GGML_TYPE_BF16,
 #else
         .vec_dot_type = GGML_TYPE_F32,
 #endif
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index bc7bcf8b6..93edc0abf 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -28,6 +28,7 @@ inline float trellis_gen(uint32_t& val, uint32_t* s) {
     return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
 }
 
+template <bool is4 = false>
 struct Trellis1 {
     constexpr static uint32_t kmask = 0x8fff8fff;
     constexpr static uint32_t km32 = 0x3b603b60;
@@ -47,16 +48,59 @@ struct Trellis1 {
     constexpr static uint32_t kb6 = kb5*ka+kb;
     constexpr static uint32_t ka7 = ka6*ka;
     constexpr static uint32_t kb7 = kb6*ka+kb;
-    const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
-    const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
+    const __m256i mka = is4 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
+    const __m256i mkb = is4 ? _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
     const __m256i mask1 = _mm256_set1_epi32(kmask);
     const __m256i mask2 = _mm256_set1_epi32(km32);
-
+#ifdef __AVX512BF16__
+    const __m512i shuf1 = load_shuffle(0);
+    const __m512i shuf2 = load_shuffle(1);
+    static __m512i load_shuffle(int i) {
+        static const int32_t k_idx[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
+                                          1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
+        return _mm512_loadu_si512((const __m512i *)k_idx + i);
+    }
+    inline __m256 gen8(__m256i i8) const {
+        auto v = _mm512_cvtph_ps(i8);
+        v = _mm512_add_ps(_mm512_permutexvar_ps(shuf1, v), _mm512_permutexvar_ps(shuf2, v));
+        return _mm512_castps512_ps256(v);
+    }
+    inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale) const {
+        auto v1 = _mm512_cvtph_ps(i8_1);
+        auto v2 = _mm512_cvtph_ps(i8_2);
+        auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2);
+        auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2);
+        auto v = _mm512_mul_ps(scale, _mm512_add_ps(vs1, vs2));
+        return __m256i(_mm512_cvtneps_pbh(v));
+    }
+    inline __m256i gen8bh(__m256i i8_1, __m256i i8_2, __m512 scale, __m512 offset) const {
+        auto v1 = _mm512_cvtph_ps(i8_1);
+        auto v2 = _mm512_cvtph_ps(i8_2);
+        auto vs1 = _mm512_permutex2var_ps(v1, shuf1, v2);
+        auto vs2 = _mm512_permutex2var_ps(v1, shuf2, v2);
+        auto v = _mm512_fmadd_ps(scale, _mm512_add_ps(vs1, vs2), offset);
+        return __m256i(_mm512_cvtneps_pbh(v));
+    }
+    inline __m256i gen8bh(uint32_t val1, uint32_t val2, __m512 scale) const {
+        return gen8bh(next8(val1), next8(val2), scale);
+    }
+    inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale) const {
+        return gen8bh(next8(val1, val2), next8(val3, val4), scale);
+    }
+    inline __m256i gen8bh(uint32_t val1, uint32_t val2, uint32_t val3, uint32_t val4, __m512 scale, __m512 offset) const {
+        return gen8bh(next8(val1, val2), next8(val3, val4), scale, offset);
+    }
+#endif
     inline __m256i next8(uint32_t val) const {
         auto mval = _mm256_set1_epi32(val);
         auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
         return _mm256_xor_si256(_mm256_and_si256(mres, mask1), mask2);
     }
+    inline __m256i next8(uint32_t val1, uint32_t val2) const {
+        __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
+        __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
+        return _mm256_xor_si256(_mm256_and_si256(mres, _mm256_set1_epi32(kmask)), _mm256_set1_epi32(km32));
+    }
 };
 
 inline __m256 trellis_gen8(__m256i i8) {
@@ -97,6 +141,93 @@ struct Trellis2 {
     }
 };
 
+#ifdef __AVX512BF16__
+template <int nrc_y>
+void mul_mat_iq2_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+    auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+    union { __m256 vec; float val[8]; } s_helper;
+
+    constexpr int k_acc = 2 * nrc_y;
+    __m256 accd[k_acc];
+    const ggml_bf16_t * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        const float d = *dptr * 31.75f * 1.05f;
+        const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            s8 = _mm_shuffle_epi8(values, s8);
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_cvtepi32_ps(s32);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+                auto xval1 = __m256bh(trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale));
+                auto xval2 = __m256bh(trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K + 32*ib + 0)));
+                    auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K + 32*ib + 16)));
+                    accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, xval1);
+                    accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, xval2);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, d*hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])));
+        }
+    }
+}
+
+void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+    auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
+
+    union { __m256 vec; float val[8]; } s_helper;
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto vd = _mm256_set1_ps(*dptr * 31.75f * 1.05f);
+        const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            s8 = _mm_shuffle_epi8(values, s8);
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+                auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
+                auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
+                _mm256_storeu_si256((__m256i *)(y + i*QK_K + 32*ib + 0), xval1);
+                _mm256_storeu_si256((__m256i *)(y + i*QK_K + 32*ib + 16), xval2);
+            }
+        }
+        y += stride_y;
+    }
+}
+#endif
+
 void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     const int nb = n/QK_K;
@@ -315,6 +446,116 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+#ifdef __AVX512BF16__
+void iqk_dequantize_iq3_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    union { __m256 vec; float val[8]; } s_helper;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    __m256i all_signs[2];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        if (dptr[0] < 0) {
+            printf("Oops: row scale is %g\n", dptr[0]);
+            GGML_ABORT("Fatal error");
+        }
+        auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            const uint8_t * qh = x[i].qh;
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
+            for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
+            auto mask = _mm256_set1_epi16(0x01);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+                auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
+                auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
+                auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
+                auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
+                auto x1 = _mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff)));
+                auto x2 = _mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff)));
+                _mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+ 0), x1);
+                _mm256_storeu_si256((__m256i *)(y+i*QK_K+32*ib+16), x2);
+                mask = _mm256_slli_epi16(mask, 1);
+            }
+        }
+        y += stride_y;
+    }
+}
+
+template <int nrc_y>
+void mul_mat_iq3_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis1 trellis;
+
+    union { __m256 vec; float val[8]; } s_helper;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    __m256i all_signs[2];
+
+    __m256 accd[2*nrc_y];
+    const ggml_bf16_t * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        if (dptr[0] < 0) {
+            printf("Oops: row scale is %g\n", dptr[0]);
+            GGML_ABORT("Fatal error");
+        }
+        auto vd = _mm256_set1_ps(dptr[0] * 31.75f * 1.015f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        for (int i = 0; i < nb; ++i) {
+            const uint16_t * ql = (const uint16_t *)x[i].ql;
+            const uint8_t * qh = x[i].qh;
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(s32));
+            for (int j = 0; j < 2; ++j) all_signs[j] = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)(qh + 16*j)));
+            auto mask = _mm256_set1_epi16(0x01);
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+                auto xval1 = trellis.gen8bh(trellis.next8(ql[4*ib+0]+4096), trellis.next8(ql[4*ib+1]+4096), scale);
+                auto xval2 = trellis.gen8bh(trellis.next8(ql[4*ib+2]+4096), trellis.next8(ql[4*ib+3]+4096), scale);
+                auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[0], mask), mask), _mm256_set1_epi16(0x8000));
+                auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi16(_mm256_and_si256(all_signs[1], mask), mask), _mm256_set1_epi16(0x8000));
+                mask = _mm256_slli_epi16(mask, 1);
+                auto x1 = __m256bh(_mm256_or_si256(sign1, _mm256_and_si256(xval1, _mm256_set1_epi16(0x7fff))));
+                auto x2 = __m256bh(_mm256_or_si256(sign2, _mm256_and_si256(xval2, _mm256_set1_epi16(0x7fff))));
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y1 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+ 0)));
+                    auto y2 = __m256bh(_mm256_loadu_si256((const __m256i *)(y[iy]+i*QK_K+32*ib+16)));
+                    accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], y1, x1);
+                    accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], y2, x2);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])));
+        }
+    }
+}
+#endif
+
 void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     const int nb = n/QK_K;
@@ -370,7 +611,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
     const int nb = n/QK_K;
     constexpr int kNumGroups = 64;
 
-    Trellis2 trellis;
+    Trellis1<true> trellis;
 
     union { __m256 vec; float val[8]; } s_helper;
     union { __m256i vec; uint32_t val[8]; } o_helper;
@@ -442,16 +683,167 @@
     }
 }
 
+#ifdef __AVX512BF16__
+void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, ggml_bf16_t * y, size_t stride_y, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    const int nb = n/QK_K;
+    constexpr int kNumGroups = 64;
+
+    Trellis1<true> trellis;
+
+    union { __m256 vec; float val[8]; } s_helper;
+    union { __m256i vec; uint32_t val[8]; } o_helper;
+
+    union { __m512i vec[2]; uint16_t val[64]; } h_helper;
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+        auto dav = _mm512_set1_ps(dptr[1]);
+        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+        for (int i = 0; i < nb; ++i) {
+            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
+            const uint32_t * shb = x[i].qs;
+            const uint8_t * ql = (const uint8_t *)(shb + 8);
+            auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
+            auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
+            auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
+            h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
+            h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
+            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+            s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
+            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+
+                uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
+                uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
+                uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
+                uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib];
+                auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
+
+                val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib];
+                val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib];
+                val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib];
+                val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib];
+                auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale, dav);
+
+                _mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+ 0), xval1);
+                _mm256_storeu_si256((__m256i *)(y + i*QK_K+32*ib+16), xval2);
+            }
+        }
+
+        y += stride_y;
+
+    }
+}
+
+template <int nrc_y>
+void mul_mat_iq4_kt_BF16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    const int nb = n/QK_K;
+    constexpr int kNumGroups = 64;
+
+    Trellis1<true> trellis;
+
+    union { __m256 vec; float val[8]; } s_helper;
+    union { __m256i vec; uint32_t val[8]; } o_helper;
+
+    __m256 accd[2*nrc_y];
+    const ggml_bf16_t * y[nrc_y];
+    float row_sum[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
+        auto sum = _mm256_setzero_ps();
+        auto one = _mm512_cvtneps_pbh(_mm512_set1_ps(1.f));
+        for (int i = 0; i < n/16; ++i) sum = _mm256_dpbf16_ps(sum, one, __m256bh(_mm256_loadu_si256((const __m256i *)y[iy] + i)));
+        row_sum[iy] = hsum_float_8(sum);
+    }
+
+    union { __m512i vec[2]; uint16_t val[64]; } h_helper;
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+        auto dav = dptr[1];
+        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+
+        for (int iy = 0; iy < 2*nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        for (int i = 0; i < nb; ++i) {
+            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
+            const uint32_t * shb = x[i].qs;
+            const uint8_t * ql = (const uint8_t *)(shb + 8);
+            auto vql1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+0));
+            auto vql2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)ql+1));
+            auto vqh = _mm512_cvtepu8_epi16(_mm256_loadu_si256((const __m256i *)(ql + kNumGroups)));
+            h_helper.vec[0] = _mm512_add_epi16(vql1, _mm512_and_si512(_mm512_slli_epi16(vqh, 8), _mm512_set1_epi16(0xf00)));
+            h_helper.vec[1] = _mm512_add_epi16(vql2, _mm512_and_si512(_mm512_slli_epi16(vqh, 4), _mm512_set1_epi16(0xf00)));
+            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+            s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
+            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+            for (int ib = 0; ib < QK_K/32; ++ib) {
+                auto scale = _mm512_set1_ps(s_helper.val[ib]);
+
+                uint32_t val1 = h_helper.val[8*ib+0] + ((shb[ib] << 4) & 0x7000) + o_helper.val[ib];
+                uint32_t val2 = h_helper.val[8*ib+1] + ((shb[ib] << 1) & 0x7000) + o_helper.val[ib];
+                uint32_t val3 = h_helper.val[8*ib+2] + ((shb[ib] >> 2) & 0x7000) + o_helper.val[ib];
+                uint32_t val4 = h_helper.val[8*ib+3] + ((shb[ib] >> 5) & 0x7000) + o_helper.val[ib];
+                auto xval1 = trellis.gen8bh(val1, val2, val3, val4, scale);
+
+                val1 = h_helper.val[8*ib+4] + ((shb[ib] >> 8) & 0x7000) + o_helper.val[ib];
+                val2 = h_helper.val[8*ib+5] + ((shb[ib] >> 11) & 0x7000) + o_helper.val[ib];
+                val3 = h_helper.val[8*ib+6] + ((shb[ib] >> 14) & 0x7000) + o_helper.val[ib];
+                val4 = h_helper.val[8*ib+7] + ((shb[ib] >> 17) & 0x7000) + o_helper.val[ib];
+                auto xval2 = trellis.gen8bh(val1, val2, val3, val4, scale);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y1 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+ 0));
+                    auto y2 = _mm256_loadu_si256((const __m256i *)(y[iy] + i*QK_K+32*ib+16));
+                    accd[2*iy+0] = _mm256_dpbf16_ps(accd[2*iy+0], __m256bh(y1), __m256bh(xval1));
+                    accd[2*iy+1] = _mm256_dpbf16_ps(accd[2*iy+1], __m256bh(y2), __m256bh(xval2));
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accd[2*iy], accd[2*iy+1])) + dav*row_sum[iy]);
+        }
+    }
+}
+#endif
+
 } // namespace
 
 bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
-    if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) {
-        return false;
-    }
+    if (ne00%QK_K != 0) return false;
 
     func16 = nullptr;
 
+#ifdef __AVX512BF16__
+    if (typeA == GGML_TYPE_IQ2_KT) {
+        if (typeB != GGML_TYPE_BF16) return false;
+        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_BF16_T, kernels);
+        return true;
+    }
+    if (typeA == GGML_TYPE_IQ3_KT) {
+        if (typeB != GGML_TYPE_BF16) return false;
+        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_BF16_T, kernels);
+        return true;
+    }
+    if (typeA == GGML_TYPE_IQ4_KT) {
+        if (typeB != GGML_TYPE_BF16) return false;
+        IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_BF16_T, kernels);
+        return true;
+    }
+#endif
+
+    if (ggml_type(typeB) != GGML_TYPE_F32) {
+        return false;
+    }
+
     switch (typeA) {
         case GGML_TYPE_IQ2_KT:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F32_T, kernels);
@@ -472,9 +864,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
+#ifdef __AVX512BF16__
+        case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
+        case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
+        case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_BF16 : type;
+#else
         case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
         case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
         case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
+#endif
         default: break;
     }
 #else
@@ -606,7 +612,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         case GGML_TYPE_IQ2_KT:
         case GGML_TYPE_IQ3_KT:
         case GGML_TYPE_IQ4_KT:
-            return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false;
+            return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16);
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0: