From e558992f0ce34615bbc74c7ffc00fbaf617dba34 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 12:30:37 +0300 Subject: [PATCH 01/33] New iq4_kt trellis The new trellis generates int8_t values via sum_as_uint8_t[(ka * idx + kb) & 0x3f33f3f3f] - 126. CUDA dequantize works. AVX2 case Ny > 32 works, and we get 273 t/s for L3-8B. PPL is on par or even slightly lower than original QTIP trellis. --- ggml/src/ggml-cuda/convert.cu | 12 +-- ggml/src/ggml.c | 11 +- ggml/src/iqk/iqk_gemm_ktquants.cpp | 157 +++++++++++++++++++++++++++-- ggml/src/iqk/iqk_mul_mat.cpp | 6 +- ggml/src/iqk/iqk_quantize.cpp | 29 ++++-- 5 files changed, 181 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 01b7250e3..b1c447ab3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -343,13 +343,9 @@ inline __device__ int nearest_int(float fval) { float __device__ __forceinline__ trellis_next(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; - uint32_t s; - const half * h = (const half *)&s; val = ka*val + kb; - s = (val & kmask) ^ km32; - return (float)(h[0]+h[1]); + //return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, 0x82828282); + return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); } template @@ -367,7 +363,7 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; for (int j = 0; j < 8; ++j) { y[j] = dl * trellis_next(idx); } @@ -401,7 +397,7 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t ii = blockIdx.x; int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); - float scale = dptr[0] * 31.75f * 1.01f; + float scale = dptr[0] * 1.00f; float row_av = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); const int64_t i = ii - (row*n_per_row)/QK_K; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 69b1b46d7..b1b9b57bd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1630,11 +1630,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, .vec_dot = vec_dot_iq4_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif +//#ifdef __ARM_NEON +// .vec_dot_type = GGML_TYPE_F16, +//#else +// .vec_dot_type = GGML_TYPE_F32, +//#endif .nrows = 1, .row_meta_size = 8, }, diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index bc7bcf8b6..0a8d2d037 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -97,6 +97,43 @@ struct Trellis2 { } }; + +struct Trellis3 { + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t 
ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); + const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); + const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + + inline __m256i next8(uint32_t val1, uint32_t val2) const { + __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); + return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + } + inline __m256 gen8(uint32_t val1, uint32_t val2) const { + auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); + auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); + return _mm256_cvtepi32_ps(i8); + } + inline __m256i next32(const uint32_t * val) const { + __m256i aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); + aux[i] = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), i8); + } + aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } +}; + void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -315,19 +352,121 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +// Q8_0 repacking: +// for (int ib = 0; ib < nblock; ++ib) { +// for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; +// for (int l = 0; l < 4; ++l) { +// for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { +// y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; +// y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; +// as uint32_t +// y[ib].qs[8*l+k+ 0] = x8[k][ib].qs[l+ 0]; +// y[ib].qs[8*l+k+32] = x8[k][ib].qs[l+16]; +// } +// } +// } + +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[16]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + //for (int k = 0; k < 8; ++k) { + // auto shb = x8[k][i].qs; + // const uint8_t * ql = (const uint8_t *)(shb + 8); + // const uint8_t * qh = ql + kNumGroups; + // for (int ib = 0; ib < 4; ++ib) { + // uint32_t offset1 = ((shb[ib+0] & 1) << 15) + 4096; + // uint32_t offset2 = ((shb[ib+4] & 1) << 15) + 4096; + // for 
(int j = 0; j < 4; ++j) { + // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // idx[64*ib + 16*j + k ] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + // idx[64*ib + 16*j + k + 8] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + // idx[64*ib + 16*j + k + 256] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + // idx[64*ib + 16*j + k + 264] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + // //uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + // //uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + // //uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + // //uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + // //auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); + // //auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); + // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); + // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); + // } + // } + //} + //for (int j = 0; j < 64; ++j) { + // _mm256_storeu_si256((__m256i *)y[j/8].qs+(j%8), trellis.next32(idx+8*j)); + //} + //int shift1 = 8 - 4*(ib/4); + //for (int j = 0; j < 4; ++j) { + // for (int k = 0; k < 8; ++k) { + // const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + // const uint8_t * qh = ql + kNumGroups; + // const uint32_t sh = x8[k][i].qs[ib] >> (8 + 6*j); + // idx[k+0] = ql[8*ib+2*j+0] + ((qh[8*(ib%4)+2*j+0] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + // idx[k+8] = ql[8*ib+2*j+1] + ((qh[8*(ib%4)+2*j+1] << shift1) & 0xf00) + ((sh & 56) << 9) + idx0[k]; + // } + // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, trellis.next32(idx+0)); + // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, trellis.next32(idx+8)); + //} + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + _mm256_storeu_si256((__m256i *)y[ib].qs+j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + + } +} + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = _mm256_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -349,8 +488,8 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = 
_mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav); - auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav); + auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); + auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); @@ -370,7 +509,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; @@ -389,7 +528,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -413,8 +552,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3))); - auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4))); + auto x_val1 = _mm256_mul_ps(scale1, trellis.gen8(val1, val3)); + auto x_val2 = _mm256_mul_ps(scale2, trellis.gen8(val2, val4)); if constexpr (nrc_y == 1) { auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0); auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128); @@ -474,7 +613,7 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * switch (type) { case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; - case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break; default: return false; } return true; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 6925e6a6a..ce67159c8 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -236,9 +236,6 @@ struct MulMat { static inline ggml_type is_dequant_better(ggml_type type, int nrc_y) { #ifdef __AVX2__ switch (type) { - case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_IQ2_S : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type; @@ -267,6 +264,9 @@ struct MulMat { case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; + case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? 
GGML_TYPE_F32 : type; + case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #else diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index abd4be610..7061c8bfd 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7408,7 +7408,7 @@ class QuantizerIQKT { constexpr static int kNg = kBlockSize/kGroupSize; constexpr static int kNblock = kSuperBlockSize/kBlockSize; constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 - constexpr static float kScale = 31.75f; + constexpr static float kScale = 1.f; //31.75f; constexpr static bool kVerbose = false; QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); @@ -7421,15 +7421,19 @@ class QuantizerIQKT { static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; + //constexpr uint32_t kmask = 0x8fff8fff; + //constexpr uint32_t km32 = 0x3b603b60; uint32_t x = i + offset; + uint32_t s; + auto i8 = (const int8_t *)&s; for (int k = 0; k < kGroupSize; ++k) { x = ka*x + kb; - uint32_t s = (x & kmask) ^ km32; - float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); - if constexpr (is_abs) result[k] = scale*std::abs(val); - else result[k] = scale*val; + s = x & 0x3f3f3f3f; + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + //uint32_t s = (x & kmask) ^ km32; + //float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + //if constexpr (is_abs) result[k] = scale*std::abs(val); + //else result[k] = scale*val; } } @@ -8209,7 +8213,7 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) { assert(k % QuantizerIQ2KT::kSuperBlockSize == 0); #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; #endif const int nb = k / QuantizerIQ2KT::kSuperBlockSize; const float * dptr = (const float *)x; @@ -8560,7 +8564,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f row_av += x[j]; amax_row = std::max(amax_row, std::abs(x[j])); } - row_av /= n_per_row; + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + //row_av /= n_per_row; + row_av = 0; + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
dptr[1] = row_av; if (!amax_row) { dptr[0] = 0.f; @@ -8593,7 +8600,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f continue; } float best = 0; - float scale_0 = std::max(92.f, 127.f*amax/amax_row); + float scale_0 = std::max(90.f, 124.f*amax/amax_row); for (int itry = -kNtry; itry <= kNtry; ++itry) { quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx); @@ -8724,7 +8731,7 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ4KT; assert(k % Q::kSuperBlockSize == 0); From de0b38dcdcf0529deaee919e921b0c74ef54714d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 16:18:58 +0300 Subject: [PATCH 02/33] Something is not working with the AVX2 dot product --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 264 ++++++++++++++++++++++------- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 208 insertions(+), 58 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 0a8d2d037..5152f33dd 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -120,11 +120,13 @@ struct Trellis3 { auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); return _mm256_cvtepi32_ps(i8); } + template inline __m256i next32(const uint32_t * val) const { + const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126); __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); - aux[i] = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), i8); + aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); } aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -352,20 +354,6 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } -// Q8_0 repacking: -// for (int ib = 0; ib < nblock; ++ib) { -// for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; -// for (int l = 0; l < 4; ++l) { -// for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { -// y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; -// y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; -// as uint32_t -// y[ib].qs[8*l+k+ 0] = x8[k][ib].qs[l+ 0]; -// y[ib].qs[8*l+k+32] = x8[k][ib].qs[l+16]; -// } -// } -// } - void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { GGML_ASSERT(n%QK_K == 0); GGML_ASSERT(nrc_x%8 == 0); @@ -397,46 +385,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); - //for (int k = 0; k < 8; ++k) { - // auto shb = x8[k][i].qs; - // const uint8_t * ql = (const uint8_t *)(shb + 8); - // const uint8_t * qh = ql + kNumGroups; - // for (int ib = 0; ib < 4; ++ib) { - // uint32_t offset1 = ((shb[ib+0] & 1) << 15) + 4096; - 
// uint32_t offset2 = ((shb[ib+4] & 1) << 15) + 4096; - // for (int j = 0; j < 4; ++j) { - // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - // idx[64*ib + 16*j + k ] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - // idx[64*ib + 16*j + k + 8] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - // idx[64*ib + 16*j + k + 256] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - // idx[64*ib + 16*j + k + 264] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; - // //uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - // //uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - // //uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - // //uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; - // //auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); - // //auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); - // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); - // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); - // } - // } - //} - //for (int j = 0; j < 64; ++j) { - // _mm256_storeu_si256((__m256i *)y[j/8].qs+(j%8), trellis.next32(idx+8*j)); - //} - //int shift1 = 8 - 4*(ib/4); - //for (int j = 0; j < 4; ++j) { - // for (int k = 0; k < 8; ++k) { - // const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); - // const uint8_t * qh = ql + kNumGroups; - // const uint32_t sh = x8[k][i].qs[ib] >> (8 + 6*j); - // idx[k+0] = ql[8*ib+2*j+0] + ((qh[8*(ib%4)+2*j+0] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; - // idx[k+8] = ql[8*ib+2*j+1] + ((qh[8*(ib%4)+2*j+1] << shift1) & 0xf00) + ((sh & 56) << 9) + idx0[k]; - // } - // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, trellis.next32(idx+0)); - // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, trellis.next32(idx+8)); - //} int shift1 = 8 - 4*(ib/4); for (int j = 0; j < 8; ++j) { for (int k = 0; k < 8; ++k) { @@ -454,6 +402,92 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } } +/* +template +void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[8]; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[8]; + + union { float f; uint32_t u; } bf16_helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); + auto scales_m = _mm256_mul_ps(scales, 
_mm256_set1_ps(-126.f)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + xv[j] = trellis.next32(idx); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const auto& yb = y[iy][2*i+ib/4]; + int i4 = ib%4; + auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0); + auto vy = MM256_SET_M128I(vy8, vy8); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50)); + sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff)); + vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1); + vy = MM256_SET_M128I(vy8, vy8); + sumi = _mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50)); + sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff)); + bf16_helper.u = yb.d[i4] << 16; + auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f)); + accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]); + bf16_helper.u = yb.d[i4+4] << 16; + accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, accd[iy]); + } + } +} +*/ + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -503,6 +537,112 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t } } +template +void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { __m256i vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + uint32_t values[64]; + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + } + }; + + auto m126 = _mm256_set1_ps(-126.f); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; 
++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + iscales = _mm256_sub_epi32(iscales, _mm256_set1_epi32(64)); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(iscales)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < 4; ++ib) { + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + } + } + // sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8) + // d4*d8*sum[x_i*y_i] - 126*d4*m8 + for (int i128 = 0; i128 < 2; ++i128) { + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + //auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16)); + //auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + //m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f)); + //for (int k = 0; k < 4; ++k) { + // xv[k] = trellis.next32(values + 32*i128 + 8*k); + // auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k); + // dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + //} + //accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]); + //accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + template void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -585,11 +725,21 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) { + if (ne00%QK_K != 0) return false; + + func16 = nullptr; + + if (typeA == GGML_TYPE_IQ4_KT) { + if (typeB == GGML_TYPE_Q8_2_X4) { + 
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels); + return true; + } return false; } - func16 = nullptr; + if (ggml_type(typeB) != GGML_TYPE_F32) { + return false; + } switch (typeA) { case GGML_TYPE_IQ2_KT: diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ce67159c8..71eb42a9b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -815,7 +815,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: - return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; + return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: From 6d6e6e39c97d71c8cff7a43c9ef77ccdb35994a6 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 17:43:36 +0300 Subject: [PATCH 03/33] New iq4_kt: CUDA MMVQ --- ggml/src/ggml-cuda/common.cuh | 7 +++++ ggml/src/ggml-cuda/iqk_mmvq.cu | 46 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 ++++ ggml/src/ggml-cuda/mmvq.cu | 4 +++ 4 files changed, 62 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8f3d2a265..291378f42 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -648,6 +648,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index e5a224b47..789670b56 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -433,6 +433,44 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq4_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + 2*sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; // Why iqs/4 ? 
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs; + //const int8_t * q8 = bq8_1[ib32].qs; + const int ls = (bq4->qs[ib32] & 0xff) >> 1; + const float dl = scale * (ls - 64); + const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096; + auto ql = (const uint8_t *)(bq4->qs + 8); + auto qh = ql + 64; + ql += 8*ib32; + qh += 8*(ib32%4); + const int shift1 = 8 - 4*(ib32/4); + int sumi = 0; + for (int j = 0; j < 8; ++j) { + const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j); + uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val = ka*val + kb; + //int s = val & km; + //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126); + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[j], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1217,6 +1255,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 17bf5ad2c..97c172f33 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -100,3 +100,8 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 73caabab7..6c230050d 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -526,6 +526,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm break; case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + case GGML_TYPE_IQ4_KT: + mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, 
stream); @@ -687,6 +690,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: + case GGML_TYPE_IQ4_KT: return true; default: return false; From b5524af7a4439bf7c2db031b686778b225c88366 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:21:43 +0300 Subject: [PATCH 04/33] New iq4_kt: CUDA MMQ --- ggml/src/ggml-cuda/mmq.cu | 4 ++ ggml/src/ggml-cuda/mmq.cuh | 84 +++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.cpp | 11 ++++- 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a13be11b1..21bf9003a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -100,6 +100,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS_R4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ4_KT: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case(ctx, args, stream); break; @@ -172,6 +175,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ4_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 608de8f00..7d829c7e6 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -93,6 +93,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ4_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -202,6 +203,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } } @@ -250,6 +252,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } } @@ -2790,6 +2793,79 @@ template static __device__ __forceinlin } +template static __device__ __forceinline__ void load_tiles_iq4_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + 2*sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto shb = bxi->qs; + const auto ql = (const uint8_t *)(shb + 8); + const auto qh = ql + 64; + const uint32_t sh = shb[ib32] >> (8 + 6*j); + uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); + uint32_t val1 = offset + 
ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12); + uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val1 = ka*val1 + kb; + val2 = ka*val2 + kb; + v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; + v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 2) + kbx0; + const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3382,6 +3458,13 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; @@ -3843,6 +3926,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 7061c8bfd..0fd4ed862 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7490,6 +7490,7 @@ QuantizerIQKT::QuantizerIQKT(int num_c set_values(i, data, kScale, offset); data += kGroupSize; } + if (num_clusters == 0) return; // Make 128 clusters. // Note: we get a slightly better result by using 64 clusters // at the expense of almost doubling the quantization time. 
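
[Editor's note, not part of the patch series] The iq4_kt kernels added in these patches (CUDA dequantize/MMVQ/MMQ, AVX2, and later NEON) all implement the same integer trellis: each group's 15-bit packed index, one extra per-block bit, and a constant 4096 offset seed a small LCG, and every LCG step yields one int8 value as the byte sum of (state & 0x3f3f3f3f) minus 126. As a reading aid, here is a minimal scalar sketch of decoding one 4-value group; the constants and bit layout are taken from trellis_next_int(), Trellis3 and load_tiles_iq4_kt() in the surrounding diffs, while the function and parameter names are the editor's own:

    #include <cstdint>

    // Decode the 4 int8 trellis values of group j (0..7) inside 32-value block ib32 (0..7).
    // shb = x->qs (8 uint32 words per super-block), ql = (const uint8_t *)(shb + 8), qh = ql + 64.
    static void iq4_kt_decode_group(const uint32_t * shb, const uint8_t * ql, const uint8_t * qh,
                                    int ib32, int j, int8_t * out) {
        constexpr uint32_t ka = 89226354, kb = 64248484;
        const uint32_t sh    = shb[ib32] >> (8 + 3*j);          // 3 index bits per group from the shb word
        const int      shift = 8 - 4*(ib32/4);                  // qh packs two 4-bit chunks per byte
        uint32_t val = ql[8*ib32 + j]                           // index bits 0..7
                     + ((qh[8*(ib32%4) + j] << shift) & 0xf00)  // index bits 8..11
                     + ((sh & 7) << 12)                         // index bits 12..14
                     + ((shb[ib32] & 1) << 15)                  // per-block offset bit
                     + 4096;                                    // constant trellis offset
        for (int k = 0; k < 4; ++k) {
            val = ka*val + kb;                                  // LCG step
            const uint32_t s = val & 0x3f3f3f3f;                // keep the low 6 bits of each byte
            const int sum = (s & 0xff) + ((s >> 8) & 0xff) + ((s >> 16) & 0xff) + ((s >> 24) & 0xff);
            out[k] = static_cast<int8_t>(sum - 126);            // roughly zero-centered, range -126..126
        }
    }

The dequantized weight is this value times the block scale ((shb[ib32] & 0xff) >> 1) - 64 and the per-row float stored in front of the blocks (dptr[0]; dptr[1], the row average, is forced to zero for the new trellis).
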
@@ -8540,6 +8541,14 @@ const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { return *quantizer1; } +const QuantizerIQ4KT& iq4kt_dequantizer() { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unique_ptr dequantizer; + if (!dequantizer) dequantizer = std::make_unique(0, 0, 4096); + return *dequantizer; +} + void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { constexpr float kSigmaScale = 2.0f; @@ -8741,7 +8750,7 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { const float d = dptr[0] * Q::kScale; const float row_av = dptr[1]; x = (const block_iq4_kt *)(dptr + 2); - auto& deq = iq4kt_quantizer(); + auto& deq = iq4kt_dequantizer(); for (int ibl = 0; ibl < nb; ++ibl) { auto shb = x[ibl].qs; auto ql = (const uint8_t *)(shb + Q::kNblock); From 6ba96c8b33f17056bb008f4c3e50faff4efe3597 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:35:10 +0300 Subject: [PATCH 05/33] For now have only iq4_kt use the new trellis --- ggml/src/iqk/iqk_quantize.cpp | 60 +++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 0fd4ed862..b45478298 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7397,7 +7397,7 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) { } namespace { -template +template class QuantizerIQKT { static_assert(group_size == 8 || group_size == 4); static_assert(block_size >= 8 && block_size%8 == 0); @@ -7408,7 +7408,7 @@ class QuantizerIQKT { constexpr static int kNg = kBlockSize/kGroupSize; constexpr static int kNblock = kSuperBlockSize/kBlockSize; constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 - constexpr static float kScale = 1.f; //31.75f; + constexpr static float kScale = is_int ? 
1.f : 31.75f; constexpr static bool kVerbose = false; QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); @@ -7421,19 +7421,25 @@ class QuantizerIQKT { static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - //constexpr uint32_t kmask = 0x8fff8fff; - //constexpr uint32_t km32 = 0x3b603b60; uint32_t x = i + offset; - uint32_t s; - auto i8 = (const int8_t *)&s; - for (int k = 0; k < kGroupSize; ++k) { - x = ka*x + kb; - s = x & 0x3f3f3f3f; - result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); - //uint32_t s = (x & kmask) ^ km32; - //float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); - //if constexpr (is_abs) result[k] = scale*std::abs(val); - //else result[k] = scale*val; + if constexpr (is_int) { + uint32_t s; + auto i8 = (const int8_t *)&s; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + s = x & 0x3f3f3f3f; + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } + } else { + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + uint32_t s = (x & kmask) ^ km32; + float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + if constexpr (is_abs) result[k] = scale*std::abs(val); + else result[k] = scale*val; + } } } @@ -7482,8 +7488,8 @@ class QuantizerIQKT { float m_mid[4*kGroupSize]; }; -template -QuantizerIQKT::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { +template +QuantizerIQKT::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { m_values.resize(kNumVal*kGroupSize); float * data = m_values.data(); for (int i = 0; i < kNumVal; ++i) { @@ -7499,8 +7505,8 @@ QuantizerIQKT::QuantizerIQKT(int num_c m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values); } -template -std::pair QuantizerIQKT::find_best_scale( +template +std::pair QuantizerIQKT::find_best_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumq2 = 0; #ifdef __AVX2__ @@ -7532,8 +7538,8 @@ std::pair QuantizerIQKT: return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); } -template -float QuantizerIQKT::find_best_inverse_scale( +template +float QuantizerIQKT::find_best_inverse_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumx2 = 0; #ifdef __AVX2__ @@ -7565,8 +7571,8 @@ float QuantizerIQKT::find_best_inverse return sumx2 > 0 ? 
sumqx/sumx2 : 0.f; } -template -void QuantizerIQKT::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { +template +void QuantizerIQKT::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { if (!d) { std::memset(best_idx, 0, kNg*sizeof(int)); return; @@ -7744,8 +7750,8 @@ void QuantizerIQKT::find_best_match(fl #endif } -template -std::vector> QuantizerIQKT::finalize_clusters(int num_neighbours, +template +std::vector> QuantizerIQKT::finalize_clusters(int num_neighbours, const std::vector& values, const std::vector& clusters, std::vector>& c_values) { int ncluster = clusters.size()/kGroupSize; std::vector> p_in_cluster(ncluster); @@ -7831,8 +7837,8 @@ std::vector> QuantizerIQKT -std::vector QuantizerIQKT::cluster_points(const std::vector& points, int ncluster, int niter, float * mid) { +template +std::vector QuantizerIQKT::cluster_points(const std::vector& points, int ncluster, int niter, float * mid) { constexpr int ndim = kGroupSize; GGML_ASSERT(points.size() % ndim == 0); int npoint = points.size() / ndim; @@ -8526,7 +8532,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace{ -using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>; +using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15, false, true>; const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { static std::mutex mutex; From 36fba1fff20b29055d8f73d6fdc50a9d44718f2d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:51:02 +0300 Subject: [PATCH 06/33] Fix iq2_kt that got broken along the way --- ggml/src/ggml-cuda/convert.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b1c447ab3..c6e064f6e 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -340,14 +340,25 @@ inline __device__ int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -float __device__ __forceinline__ trellis_next(uint32_t& val) { +int __device__ __forceinline__ trellis_next_int(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; val = ka*val + kb; - //return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, 0x82828282); return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); } +float __device__ __forceinline__ trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t s; + const half * h = (const half *)&s; + val = ka*val + kb; + s = (val & kmask) ^ km32; + return (float)(h[0]+h[1]); +} + template static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { @@ -363,7 +374,7 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; for (int j = 0; j < 8; ++j) { y[j] = dl * trellis_next(idx); } @@ -398,7 +409,6 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); float scale = 
dptr[0] * 1.00f; - float row_av = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); const int64_t i = ii - (row*n_per_row)/QK_K; @@ -419,8 +429,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int ls = ((shb[ib32] & 0xff) >> 1) - 64; const float dl = scale * ls; for (int j = 0; j < 4; ++j) { - y[j+0] = dl * trellis_next(idx1) + row_av; - y[j+4] = dl * trellis_next(idx2) + row_av; + y[j+0] = dl * trellis_next_int(idx1); + y[j+4] = dl * trellis_next_int(idx2); } } From 78411343cc14210b3a311176a11352edc51ee26d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 19:12:09 +0300 Subject: [PATCH 07/33] New iq4_kt: AVX2 dot product finally works We get 13.6 t/s vs 8.4 t/s with the f16 trellis and f32 arithmetic. Still somewhat slower than other quants, but no longer pathetic. --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 111 ++--------------------------- 1 file changed, 7 insertions(+), 104 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 5152f33dd..d2774e2a0 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -402,92 +402,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } } -/* -template -void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%QK_K == 0); - const int nb = n/QK_K; - constexpr int kNumGroups = 64; - - Trellis3 trellis; - - constexpr int k_acc = nrc_y; - - __m256 accd[k_acc]; - const block_q8_2_x4 * y[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) { - y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); - } - - __m256i xv[8]; - - const block_iq4_kt * x8[8]; - float dkt[8]; - int32_t ls[8]; - uint32_t idx0[8], idx[8]; - - union { float f; uint32_t u; } bf16_helper; - - for (int ix = 0; ix < nrc_x; ix += 8) { - for (int k = 0; k < 8; ++k) { - const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); - dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); - } - auto vd = _mm256_loadu_ps(dkt); - - for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - for (int ib = 0; ib < QK_K/32; ++ib) { - for (int k = 0; k < 8; ++k) { - ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; - idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; - } - auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-126.f)); - int shift1 = 8 - 4*(ib/4); - for (int j = 0; j < 8; ++j) { - for (int k = 0; k < 8; ++k) { - const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); - const uint8_t * qh = ql + kNumGroups; - const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); - idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; - } - xv[j] = trellis.next32(idx); - } - for (int iy = 0; iy < nrc_y; ++iy) { - const auto& yb = y[iy][2*i+ib/4]; - int i4 = ib%4; - auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0); - auto vy = MM256_SET_M128I(vy8, vy8); - auto sumi = _mm256_setzero_si256(); - sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50)); - sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff)); - vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1); - vy = MM256_SET_M128I(vy8, vy8); - sumi = 
_mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50)); - sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff)); - bf16_helper.u = yb.d[i4] << 16; - auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f)); - accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]); - bf16_helper.u = yb.d[i4+4] << 16; - accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]); - } - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, accd[iy]); - } - } -} -*/ - void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -573,11 +487,12 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& auto compute_dot = [&dot, &xv] (const int8_t * y) { for (int k = 0; k < 4; ++k) { auto yv = _mm256_loadu_si256((const __m256i *)y + k); - dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); } }; - auto m126 = _mm256_set1_ps(-126.f); + //auto m126 = _mm256_set1_ps(-126.f); for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); @@ -609,30 +524,18 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; } } - // sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8) - // d4*d8*sum[x_i*y_i] - 126*d4*m8 for (int i128 = 0; i128 < 2; ++i128) { - for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); - //auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16)); - //auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); - //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); - //m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f)); - //for (int k = 0; k < 4; ++k) { - // xv[k] = trellis.next32(values + 32*i128 + 8*k); - // auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k); - // dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); - //} - //accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]); - //accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]); + //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); for (int iy = 0; iy < nrc_y; ++iy) { const block_q8_2_x4& yb = y[iy][2*i+i128]; auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); dy = _mm256_mul_ps(scales[i128], dy); auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); - auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); compute_dot(yb.qs); accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); - accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); } } } From 
608e0f497b01cf3e2f34ba84d271f462a838d4cd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 19:33:33 +0300 Subject: [PATCH 08/33] New iq4_kt: fix vanilla AVX2 --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index d2774e2a0..28939fffd 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -117,7 +117,12 @@ struct Trellis3 { } inline __m256 gen8(uint32_t val1, uint32_t val2) const { auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); +#else + auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); + auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif return _mm256_cvtepi32_ps(i8); } template @@ -126,7 +131,12 @@ struct Trellis3 { __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); +#else + auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif } aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -487,8 +497,13 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& auto compute_dot = [&dot, &xv] (const int8_t * y) { for (int k = 0; k < 4; ++k) { auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif } }; From 68ef8a7ae9b393d827ad675b17d16d9e63093d25 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 08:31:56 +0300 Subject: [PATCH 09/33] New iq4_kt: NEON implementation We get very respectable PP-512 = 120 t/s. TG-128 is pathetic at 5.3 t/s, so 20+% slower than the f16 variant. 
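For reference, a scalar sketch of the integer trellis that the NEON kernels below vectorize (the same recurrence as trellis_next_int on CUDA and Trellis3::next32 on AVX2). This is illustrative only; the SIMD code in the diff is the actual implementation:

    #include <cstdint>

    // One step of the new integer trellis: advance the 32-bit LCG, keep the
    // low 6 bits of each byte, and sum the four bytes. Each byte is in 0..63,
    // so the sum minus 126 lands in [-126, 126] and fits an int8_t.
    static inline int8_t trellis_next_int_scalar(uint32_t & val) {
        constexpr uint32_t ka = 89226354;  // LCG multiplier
        constexpr uint32_t kb = 64248484;  // LCG increment
        val = ka*val + kb;
        const uint32_t s = val & 0x3f3f3f3f;
        return (int8_t)(((s >> 0) & 0xff) + ((s >> 8) & 0xff) +
                        ((s >> 16) & 0xff) + ((s >> 24) & 0xff) - 126);
    }

The NEON Trellis3 below gets the same result without a serial dependence: mka/mkb hold ka, ka^2, ka^3, ka^4 and the matching accumulated increments, so a single vmlaq_u32 produces four consecutive trellis steps from one seed, and the per-byte sum is done with vdotq_s32 against a vector of ones (replaced by pairwise vpaddq_s8 adds in the later "slightly faster NEON" patches).
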
--- ggml/src/iqk/iqk_gemm_ktquants.cpp | 207 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 205 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 28939fffd..cdee7afc1 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1199,10 +1199,213 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +struct Trellis3 { + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3}; + const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3}; + const uint8x16_t shuffle = load_shuffle(); + + inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const { + uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)}; + result.val[0] = vmlaq_u32(mkb, mka, result.val[0]); + result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); + return result; + } + //inline int8x16x2_t next32(const uint32_t * val) const { + // int8x16x4_t aux; + // int8x16x2_t result; + // for (int i = 0; i < 2; ++i) { + // auto i8 = next8(val[4*i+0], val[4*i+1]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // aux.val[0] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); + // aux.val[1] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); + // i8 = next8(val[4*i+2], val[4*i+3]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // aux.val[2] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); + // aux.val[3] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); + // result.val[i] = vqtbl4q_s8(aux, shuffle); + // } + // return result; + //} + // This works: + inline int8x16x2_t next32(const uint32_t * val) const { + uint16x8_t aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = next8(val[2*i+0], val[2*i+1]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); + auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); + aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + } + int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; + return result; + } + static uint8x16_t load_shuffle() { + static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + return vld1q_u8(k_shuffle); + } +}; + +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[8]; + + for (int ix = 0; ix < nrc_x; ix 
+= 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = vld1q_f32_x2(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0))); + auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4))); + vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1)); + vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + } +} + +template +void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + uint32_t values[64]; + int8x16x2_t xv[4]; + int32x4x4_t dot; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto vshb = vld1q_u32_x2(x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1)); + auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1)); + iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64)); + iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64)); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1)); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2)); + o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + for (int ib = 0; ib < 4; ++ib) { + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + 
values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + } + } + for (int i128 = 0; i128 < 2; ++i128) { + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& yb = y[iy][2*i+i128]; + auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d))); + //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16)); + //dy = vmulq_f32(scales.val[i128], dy); + auto sumi = compute_dot(yb.qs); + accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi)); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } +} + } bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { + + if (ne00%QK_K != 0) return false; + + func16 = nullptr; + + if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) { + if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels); + return true; + } + return false; + } + //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) { // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); // func16 = nullptr; @@ -1213,8 +1416,6 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F16 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; - case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #endif From dcb464a4cbb7b183d0012d7c9f65143682cfe2a4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 08:43:02 +0300 Subject: [PATCH 10/33] New iq4_kt: slightly faster NEON --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index cdee7afc1..4e06b1b8f 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1237,17 +1237,32 @@ struct Trellis3 { // return result; //} // This works: + //inline int8x16x2_t next32(const uint32_t * val) const { + // uint16x8_t aux[4]; + // for (int i = 0; i < 4; ++i) { + // auto i8 = next8(val[2*i+0], val[2*i+1]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); + // auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); + // aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + // } + // int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; + // return result; + //} inline int8x16x2_t next32(const uint32_t * val) const { - uint16x8_t aux[4]; - for (int i = 0; i < 4; ++i) { - auto i8 = next8(val[2*i+0], val[2*i+1]); + int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; + for (int i = 0; i < 2; ++i) { + auto i8 = next8(val[4*i+0], val[4*i+1]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), 
vreinterpretq_s8_u32(i8.val[1])); + i8 = next8(val[4*i+2], val[4*i+3]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); - auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); - aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2)); } - int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; return result; } static uint8x16_t load_shuffle() { From 4102aa998c2bc55423e9957ad70bee662e6f0b1c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 09:17:32 +0300 Subject: [PATCH 11/33] New iq4_kt: slightly faster NEON --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 48 ++++++++++++++++++------------ 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 4e06b1b8f..0a17fe838 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1330,7 +1330,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; - constexpr int k_acc = nrc_y; + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; float32x4_t accd[k_acc]; @@ -1339,11 +1339,11 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); } - uint32_t values[64]; - int8x16x2_t xv[4]; + uint32_t values[16]; + int8x16x2_t xv[8]; int32x4x4_t dot; - auto compute_dot = [&dot, &xv] (const int8_t * y) { + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { for (int k = 0; k < 4; ++k) { auto yv = vld1q_s8_x2(y + 32*k); dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); @@ -1379,27 +1379,37 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int j = 0; j < 4; ++j) { const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; } + xv[ib+0] = trellis.next32(values+0); + xv[ib+4] = trellis.next32(values+8); } - for (int i128 = 0; i128 < 2; ++i128) { - for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); - for (int iy = 0; iy < nrc_y; ++iy) { - const block_q8_0_x4& yb = y[iy][2*i+i128]; - auto dy = vmulq_f32(scales.val[i128], 
vcvt_f32_f16(vld1_f16((const float16_t *)yb.d))); - //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16)); - //dy = vmulq_f32(scales.val[i128], dy); - auto sumi = compute_dot(yb.qs); - accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi)); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); } } } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(accd[iy])); + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } } } } From 07d6e1d4b18f1aff6e703101a7d1227158eb7d37 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 10:05:49 +0300 Subject: [PATCH 12/33] New iq4_kt: faster NEON We are now at 9.4 t/s, up from 6.6 t/s for the f16 trellis. --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 38 +++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 0a17fe838..d4547d4a1 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1353,6 +1353,9 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& return vpaddq_s32(dot.val[0], dot.val[2]); }; + //int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}}; + int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; + float32x4x2_t scales; for (int ix = 0; ix < nrc_x; ++ix) { @@ -1376,14 +1379,33 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); for (int ib = 0; ib < 4; ++ib) { - for (int j = 0; j < 4; ++j) { - const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - } + auto vql1 = vmovl_u8(vld1_u8(ql+8*ib)); + auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32)); + auto vqh = vmovl_u8(vld1_u8(qh+8*ib)); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8))); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4))); + auto sh1_u32 = vdupq_n_u32(shb[ib+0]); + auto sh2_u32 = vdupq_n_u32(shb[ib+4]); + auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, 
shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1]))); + auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1]))); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1)); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2)); + auto oh1 = vdupq_n_u32(o_helper.val[ib+0]); + auto oh2 = vdupq_n_u32(o_helper.val[ib+4]); + vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1)); + vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); + vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); + vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); + //auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts); + //auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts); + //for (int j = 0; j < 4; ++j) { + // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + // values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + // values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + // values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + //} xv[ib+0] = trellis.next32(values+0); xv[ib+4] = trellis.next32(values+8); } From d6ac52c0d7cac9d3c0007a6c0de6896981b4d286 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 10:29:40 +0300 Subject: [PATCH 13/33] Minor --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index d4547d4a1..ba35631bf 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1353,7 +1353,6 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& return vpaddq_s32(dot.val[0], dot.val[2]); }; - //int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}}; int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; float32x4x2_t scales; @@ -1396,16 +1395,6 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); - //auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts); - //auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts); - //for (int j = 0; j < 4; ++j) { - // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - // values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - // values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - // values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - // values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - //} xv[ib+0] = trellis.next32(values+0); xv[ib+4] = trellis.next32(values+8); } @@ -1443,23 +1432,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array Date: Sun, 8 Jun 2025 11:47:55 +0300 Subject: [PATCH 14/33] New iq4_kt 
trellis: not working Metal implementation --- ggml/src/ggml-metal.metal | 42 ++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a05a890e2..f850d998a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6596,6 +6596,25 @@ void kernel_mul_mv_iq2_k_f32_impl( } } +struct Trellis3 { + constexpr constant static uint32_t kmask = 0x3f3f3f3f; + constexpr constant static uint32_t ka = 89226354; + constexpr constant static uint32_t kb = 64248484; + constexpr constant static uint32_t ka1 = ka*ka; + constexpr constant static uint32_t kb1 = kb*ka+kb; + constexpr constant static uint32_t ka2 = ka1*ka; + constexpr constant static uint32_t kb2 = kb1*ka+kb; + constexpr constant static uint32_t ka3 = ka2*ka; + constexpr constant static uint32_t kb3 = kb2*ka+kb; + static inline char4 gen4(uint32_t val) { + thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask}; + thread const int8_t * a8 = (thread const int8_t *)aux; + char4 result; + for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; + return result; + } +}; + struct Trellis { constexpr constant static uint32_t kmask1 = 0x8fff8fff; constexpr constant static uint32_t kmask2 = 0x3b603b60; @@ -8586,20 +8605,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread device const uint32_t * shb = x->qs; device const uint8_t * ql = (device const uint8_t *)(shb + 8); device const uint8_t * qh = ql + 64; - float scale = d * (((shb[ib32] & 0xff) >> 1) - 64); + const int ls = (shb[ib32] & 0xff) >> 1; + const float scale = d * (ls - 64); const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); - const int jj = ib32*8 + 4*(il%2); - ql += jj; - qh += jj%32; + ql += 8*ib32; + qh += 8*(ib32%4); uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12; - const int shift = 8 - 4*(jj/32); + const int shift = 8 - 4*(ib32/4); for (int i = 0; i < 4; ++i) { uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset; - auto v = (float4)Trellis::gen4(idx); - reg[i] = v * scale; + auto c4 = Trellis3::gen4(idx); + reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]}; } } @@ -8931,18 +8950,17 @@ struct DequantizerKT4 { using type4x4 = T4x4; DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; - d[0] = dptr[0] * 31.75f * 1.01f; - d[1] = dptr[1]; + d = dptr[0] * 1.01f; x = (device const Block *)(dptr + 2); } inline void convert(thread T4x4& t) const { float4x4 tmp; - dequantize_iq4_kt(x, il, d[0], tmp); + dequantize_iq4_kt(x, il, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void convert(int64_t ind, thread T4x4& t) { float4x4 tmp; - dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp); + dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void next() { @@ -8951,7 +8969,7 @@ struct DequantizerKT4 { } device const Block * x; short il; - float d[2]; + float d; }; template From be78290a23e38087273f1fb4add8235d19ed0175 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 13:36:47 +0300 Subject: [PATCH 15/33] Remove the extra 4 bytes of row meta data that is no longer used --- ggml/src/ggml-cuda/convert.cu | 2 +- ggml/src/ggml-cuda/iqk_mmvq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 4 ++-- 
ggml/src/ggml.c | 7 +------ ggml/src/iqk/iqk_gemm_ktquants.cpp | 4 ++-- ggml/src/iqk/iqk_quantize.cpp | 19 +++++-------------- 6 files changed, 12 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c6e064f6e..4dd053cfd 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -409,7 +409,7 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); float scale = dptr[0] * 1.00f; - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); const int64_t i = ii - (row*n_per_row)/QK_K; constexpr int kNumGroups = 64; diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 789670b56..5dcf51319 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -441,7 +441,7 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( constexpr uint32_t km = 0x3f3f3f3f; float scale = *(const float *)vbq; - const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + 2*sizeof(float)) + kbx; + const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx; // iqs is 0...28 const int ib32 = iqs/4; // Why iqs/4 ? diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 7d829c7e6..26a7933ca 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2819,7 +2819,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + 2*sizeof(float)) + kbx0; + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0; int ib32 = kqsx/4; int j = kqsx%4; @@ -2855,7 +2855,7 @@ template static __device__ __forceinlin } const float * dptr = (const float *)(x + i*stride); - const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 2) + kbx0; + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0; const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; #ifdef INT8_MMA_AVAILABLE diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b1b9b57bd..9375963ff 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1635,13 +1635,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif -//#ifdef __ARM_NEON -// .vec_dot_type = GGML_TYPE_F16, -//#else -// .vec_dot_type = GGML_TYPE_F32, -//#endif .nrows = 1, - .row_meta_size = 8, + .row_meta_size = 4, }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index ba35631bf..9c87373d4 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -383,7 +383,7 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, for (int k = 0; k < 8; ++k) { const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); + x8[k] = (const block_iq4_kt *)(dptr + 1); } auto vd = _mm256_loadu_ps(dkt); @@ -512,7 +512,7 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = _mm256_set1_ps(dptr[0]); - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); for 
(int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index b45478298..a95bc22ba 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -8565,7 +8565,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f float * dptr = (float *)vy; - block_iq4_kt * y = (block_iq4_kt *)(dptr + 2); + block_iq4_kt * y = (block_iq4_kt *)(dptr + 1); auto& quantizer1 = iq4kt_quantizer(); auto& quantizer2 = iq4kt_quantizer(true); @@ -8574,16 +8574,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); - float amax_row = 0, row_av = 0; + float amax_row = 0; for (int j = 0; j < n_per_row; ++j) { - row_av += x[j]; amax_row = std::max(amax_row, std::abs(x[j])); } - //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - //row_av /= n_per_row; - row_av = 0; - //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - dptr[1] = row_av; if (!amax_row) { dptr[0] = 0.f; std::memset(y, 0, nblock*sizeof(block_iq4_kt)); @@ -8606,7 +8600,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; float amax = 0; for (int j = 0; j < Q::kBlockSize; ++j) { - xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + xaux[j] = xbl[ib*Q::kBlockSize+j]; float ax = std::abs(xaux[j]); amax = std::max(amax, ax); } @@ -8686,7 +8680,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f for (int ib = 0; ib < Q::kNblock; ++ib) { auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1; const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; - for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j]; int ls = nearest_int(id*scales[ib]); ls = std::min(ls, 63); *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1); @@ -8754,8 +8748,7 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { const int nb = k / Q::kSuperBlockSize; const float * dptr = (const float *)x; const float d = dptr[0] * Q::kScale; - const float row_av = dptr[1]; - x = (const block_iq4_kt *)(dptr + 2); + x = (const block_iq4_kt *)(dptr + 1); auto& deq = iq4kt_dequantizer(); for (int ibl = 0; ibl < nb; ++ibl) { auto shb = x[ibl].qs; @@ -8763,14 +8756,12 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { auto qh = ql + kNumGroups; for (int ib = 0; ib < Q::kNblock; ++ib) { int offset = shb[ib] & 1 ? 32768 + 4096 : 4096; - //auto& deq = shb[ib] & 1 ? 
deq2 : deq1; int ls = int((shb[ib] & 0xff) >> 1) - 64; float sl = d * ls; for (int ig = 0; ig < Q::kNg; ++ig) { int jj = ib*Q::kNg+ig; uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12); deq.set_values(idx, y, sl, offset); - for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av; y += Q::kGroupSize; } } From 6480fa5967bce55c3863d997bad14bcea105aacc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 13:44:20 +0300 Subject: [PATCH 16/33] Cleanup --- ggml/src/ggml-metal.metal | 2 +- ggml/src/iqk/iqk_gemm_ktquants.cpp | 36 ++---------------------------- 2 files changed, 3 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f850d998a..e1d474040 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -8951,7 +8951,7 @@ struct DequantizerKT4 { DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; d = dptr[0] * 1.01f; - x = (device const Block *)(dptr + 2); + x = (device const Block *)(dptr + 1); } inline void convert(thread T4x4& t) const { float4x4 tmp; diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 9c87373d4..965ecc2da 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1218,38 +1218,6 @@ struct Trellis3 { result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); return result; } - //inline int8x16x2_t next32(const uint32_t * val) const { - // int8x16x4_t aux; - // int8x16x2_t result; - // for (int i = 0; i < 2; ++i) { - // auto i8 = next8(val[4*i+0], val[4*i+1]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // aux.val[0] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); - // aux.val[1] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); - // i8 = next8(val[4*i+2], val[4*i+3]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // aux.val[2] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); - // aux.val[3] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); - // result.val[i] = vqtbl4q_s8(aux, shuffle); - // } - // return result; - //} - // This works: - //inline int8x16x2_t next32(const uint32_t * val) const { - // uint16x8_t aux[4]; - // for (int i = 0; i < 4; ++i) { - // auto i8 = next8(val[2*i+0], val[2*i+1]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); - // auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); - // aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); - // } - // int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; - // return result; - //} inline int8x16x2_t next32(const uint32_t * val) const { int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; for (int i = 0; i < 2; ++i) { @@ -1290,7 +1258,7 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, for (int k = 0; k < 8; ++k) { const float * dptr = 
(const float *)((const char*)vx + (ix+k)*bx); dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); + x8[k] = (const block_iq4_kt *)(dptr + 1); } auto vd = vld1q_f32_x2(dkt); @@ -1360,7 +1328,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = vdupq_n_f32(dptr[0]); - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); From 41187e4a934bf8b1e24195baa0f2698110256b6c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 09:59:39 +0300 Subject: [PATCH 17/33] Adding forgottent file --- ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu new file mode 100644 index 000000000..4590f9193 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); From e095b0fa809783e93e3a1f8cbe9e40897ab1ab87 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 17:27:40 +0300 Subject: [PATCH 18/33] Switching iq2_kt to new trellis - CUDA MMQ --- ggml/src/ggml-cuda/convert.cu | 4 +- ggml/src/ggml-cuda/mmq.cu | 4 ++ ggml/src/ggml-cuda/mmq.cuh | 83 +++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.cpp | 59 +++++++++++++++++-------- 4 files changed, 130 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 4dd053cfd..cb1408e6e 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -374,9 +374,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; for (int j = 0; j < 8; ++j) { - y[j] = dl * trellis_next(idx); + y[j] = dl * trellis_next_int(idx); } } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 21bf9003a..67fa335a2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -103,6 +103,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KT: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ2_KT: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case(ctx, args, stream); break; @@ -176,6 +179,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ4_KT: + case GGML_TYPE_IQ2_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 26a7933ca..e2c76a854 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -93,6 +93,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ4_KT: 
return MMQ_Q8_1_DS_LAYOUT_D4; default: @@ -203,6 +204,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } @@ -252,6 +254,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } @@ -2866,6 +2869,78 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq2_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val = ka*val + kb; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val = ka*val + kb; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.05f; + const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3465,6 +3540,13 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = 
vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; @@ -3927,6 +4009,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index a95bc22ba..b9e9d7752 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -8006,7 +8006,7 @@ std::vector QuantizerIQKT; +using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16, false, true>; const QuantizerIQ2KT& iq2kt_quantizer() { static std::mutex mutex; @@ -8017,7 +8017,7 @@ const QuantizerIQ2KT& iq2kt_quantizer() { } void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, - float * qtmp) { + int * all_idx) { constexpr float kSigmaScale = 2.0f; using Q = QuantizerIQ2KT; @@ -8036,6 +8036,11 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) { + amax_row = std::max(amax_row, std::abs(x[j])); + } + float amax_scale = 0, max_scale = 0; for (int ibl = 0; ibl < nblock; ++ibl) { @@ -8053,9 +8058,10 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float ax = std::abs(xb[j]); amax = std::max(amax, ax); } - quantizer.find_best_match( amax/96.f, xb, weight, best_idx); + float scale_0 = std::max(90.f, 124.f*amax/amax_row); + quantizer.find_best_match( amax/scale_0, xb, weight, best_idx); auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); - quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg); + quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg); auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); auto idx = best_idx; @@ -8063,12 +8069,7 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f else { scales[ib] = dm; idx += Q::kNg; } - auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; - for (int ig = 0; ig < Q::kNg; ++ig) { - auto q = quantizer.values() + idx[ig]*Q::kGroupSize; - for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j]; - qt += Q::kGroupSize; - } + for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig]; float abs_scale = std::abs(scales[ib]); if (abs_scale > amax_scale) { @@ -8091,20 +8092,22 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float sumqx = 0, sumq2 = 0; for (int ibl = 0; ibl < nblock; ++ibl) { const float * xb = x + ibl*Q::kSuperBlockSize; - const float * qb = qtmp + ibl*Q::kSuperBlockSize; const float * wb = all_weights + ibl*Q::kSuperBlockSize; auto scales = all_scales + ibl*Q::kNblock; for (int ib = 0; ib < Q::kNblock; ++ib) { int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); float dl = iq4k_values[ls]; - for (int j = 0; j < Q::kBlockSize; ++j) { - float q = dl*qb[j]; - sumqx += wb[j]*xb[j]*q; - sumq2 += wb[j]*q*q; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto qb = quantizer.values() + Q::kGroupSize*all_idx[(ibl*Q::kSuperBlockSize 
+ ib*Q::kBlockSize)/Q::kGroupSize + ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float q = dl*qb[j]; + sumqx += wb[jj]*xb[jj]*q; + sumq2 += wb[jj]*q*q; + } } xb += Q::kBlockSize; wb += Q::kBlockSize; - qb += Q::kBlockSize; } } if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { @@ -8140,6 +8143,26 @@ void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const f float dl = d*ls; quantizer.find_best_match(dl, xb, weight, best_idx); + auto prev_idx = all_idx + (ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize; + + float mse1 = 0, mse2 = 0; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q1 = quantizer.values() + Q::kGroupSize*prev_idx[ig]; + auto q2 = quantizer.values() + Q::kGroupSize*best_idx[ig]; + for (int j = 0; j < Q::kGroupSize; ++j) { + int jj = ig*Q::kGroupSize + j; + float diff1 = xb[jj] - dl*q1[j]; + float diff2 = xb[jj] - dl*q2[j]; + mse1 += weight[jj]*diff1*diff1; + mse2 += weight[jj]*diff2*diff2; + } + } + if (mse1 < mse2) { + for (int ig = 0; ig < Q::kNg; ++ig) best_idx[ig] = prev_idx[ig]; + } else { + for (int ig = 0; ig < Q::kNg; ++ig) prev_idx[ig] = best_idx[ig]; + } + for (int j = 0; j < Q::kNg; ++j) { qs[j] = best_idx[j]; auto xl = xb + Q::kGroupSize*j; @@ -8207,10 +8230,10 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row); std::vector scales(n_per_row/QuantizerIQ2KT::kBlockSize); std::vector weights(n_per_row); - std::vector xtmp(n_per_row); + std::vector idx(n_per_row/QuantizerIQ2KT::kGroupSize); char * qrow = (char *)dst; for (int64_t row = 0; row < nrows; ++row) { - quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), idx.data()); src += n_per_row; qrow += row_size; } From 1efb3adc9bc0bcb82d5d5f97a6118a9773cbbd11 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 17:51:28 +0300 Subject: [PATCH 19/33] New iq2_kt: CUDA GEMV --- ggml/src/ggml-cuda/iqk_mmvq.cu | 43 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 ++++ ggml/src/ggml-cuda/mmvq.cu | 3 +++ 3 files changed, 51 insertions(+) diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 5dcf51319..c026ff07b 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -471,6 +471,41 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq2_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + const float dl = scale * ls * 1.05f; + auto ql = (const uint16_t *)bq2->ql; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = ql[4*ib32+j] + 4096; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val = ka*val + kb; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int 
k = 0; k < 4; ++k) { + val = ka*val + kb; + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1263,6 +1298,14 @@ void mul_mat_vec_iq4_kt_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq2_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 97c172f33..a77bef543 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -105,3 +105,8 @@ void mul_mat_vec_iq4_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq2_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 6c230050d..19a72afa5 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -529,6 +529,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_KT: mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ2_KT: + mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); @@ -691,6 +693,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: case GGML_TYPE_IQ4_KT: + case GGML_TYPE_IQ2_KT: return true; default: return false; From 8db83dac1ddec2f7fc13cf2dbdc620bcb2ed00e9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 09:00:56 +0300 Subject: [PATCH 20/33] New iq2_kt: AVX2 dequantize --- ggml/src/ggml.c | 11 ++- ggml/src/iqk/iqk_gemm_ktquants.cpp | 119 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 4 +- 3 files changed, 126 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9375963ff..edf25e35a 100644 --- 
a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1596,11 +1596,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq2_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref, .vec_dot = vec_dot_iq2_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif +//#ifdef __ARM_NEON +// .vec_dot_type = GGML_TYPE_F16, +//#else +// .vec_dot_type = GGML_TYPE_F32, +//#endif .nrows = 1, .row_meta_size = 4, }, diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 965ecc2da..5c971b316 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -98,6 +98,7 @@ struct Trellis2 { }; +template struct Trellis3 { constexpr static uint32_t ka = 89226354; constexpr static uint32_t kb = 64248484; @@ -107,14 +108,26 @@ struct Trellis3 { constexpr static uint32_t kb2 = kb1*ka+kb; constexpr static uint32_t ka3 = ka2*ka; constexpr static uint32_t kb3 = kb2*ka+kb; - const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); - const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); + constexpr static uint32_t ka4 = ka3*ka; + constexpr static uint32_t kb4 = kb3*ka+kb; + constexpr static uint32_t ka5 = ka4*ka; + constexpr static uint32_t kb5 = kb4*ka+kb; + constexpr static uint32_t ka6 = ka5*ka; + constexpr static uint32_t kb6 = kb5*ka+kb; + constexpr static uint32_t ka7 = ka6*ka; + constexpr static uint32_t kb7 = kb6*ka+kb; + const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); + const __m256i mkb = is_8 ? _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); inline __m256i next8(uint32_t val1, uint32_t val2) const { __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); } + inline __m256i next8(uint32_t val) const { + __m256i mval = _mm256_set1_epi32(val); + return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + } inline __m256 gen8(uint32_t val1, uint32_t val2) const { auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); #ifdef HAVE_FANCY_SIMD @@ -122,6 +135,16 @@ struct Trellis3 { #else auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + return _mm256_cvtepi32_ps(i8); + } + inline __m256 gen8(uint32_t val) const { + auto v8 = _mm256_and_si256(next8(val), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); +#else + auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); + auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); #endif return _mm256_cvtepi32_ps(i8); } @@ -144,6 +167,34 @@ struct Trellis3 { // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 return _mm256_permutevar8x32_epi32(aux[0], shuffle); } + template + inline void next64(const uint32_t * val, __m256i * result) const { + const __m256i offset = is_unsigned ? 
_mm256_setzero_si256() : _mm256_set1_epi32(-126); + auto vka3 = _mm256_set1_epi32(ka3), vkb3 = _mm256_set1_epi32(kb3); + __m256i aux[8]; + for (int i = 0; i < 4; ++i) { + auto i8_1 = next8(val[2*i+0], val[2*i+1]); + auto i8_2 = _mm256_add_epi32(_mm256_mullo_epi32(i8_1, vka3), vkb3); + i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f)); + i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + aux[i+0] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_1); + aux[i+4] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8_2); +#else + auto dot1 = _mm256_maddubs_epi16(i8_1, _mm256_set1_epi32(0x01010101)); + auto dot2 = _mm256_maddubs_epi16(i8_2, _mm256_set1_epi32(0x01010101)); + aux[i+0] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot1, _mm256_set1_epi16(1))); + aux[i+4] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot2, _mm256_set1_epi16(1))); +#endif + } + for (int k = 0; k < 2; ++k) { + aux[4*k+0] = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[4*k+2] = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[4*k+0] = _mm256_packs_epi16(aux[4*k+0], aux[4*k+2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + result[k] = _mm256_permutevar8x32_epi32(aux[4*k+0], shuffle); + } + } }; void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { @@ -185,6 +236,60 @@ void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t } } +void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq2_kt * x8[8]; + float dkt[8]; + float ls[8]; + float ls_all[64]; + uint32_t idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq2_kt *)(dptr + 1); + } + auto vd = _mm256_mul_ps(_mm256_set1_ps(1.05f), _mm256_loadu_ps(dkt)); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32)); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls)); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + } + __m256i packed[2]; + trellis.next64(idx, packed); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]); + } + } + y += 8; // = QK_K/32; + } + } +} + template void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -655,6 +760,14 @@ bool iqk_set_kernels_ktquants(int ne00, int 
typeA, int typeB, std::array= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; - case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type; + case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } From c9800ae62e01d3343caf8cce8ffab24f4bb06112 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 09:58:39 +0300 Subject: [PATCH 21/33] New iq2_kt: AVX2 GEMM/GEMV --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 120 +++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 5c971b316..638cbb930 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -159,6 +159,25 @@ struct Trellis3 { #else auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif + } + aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } + template + inline __m256i next32(const uint16_t * val, uint32_t v0) const { + const __m256i offset = is_unsigned ? 
_mm256_setzero_si256() : _mm256_set1_epi32(-126); + __m256i aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = _mm256_and_si256(next8(v0 + val[i]), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD + aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); +#else + auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); #endif } aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -352,6 +371,93 @@ void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +template +void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif + } + }; + + //auto m126 = _mm256_set1_ps(-126.f); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0] * 1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + for (int i128 = 0; i128 < 2; ++i128) { + //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = 
_mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + inline __m256 abs_ps(__m256 vals) { // Clear sign-bit of all the 32-bit floats in vals __m256 sign_bit = _mm256_set1_ps(-0.0f); @@ -760,13 +866,13 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array Date: Mon, 9 Jun 2025 10:26:27 +0300 Subject: [PATCH 22/33] Adding forgotten file --- ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu new file mode 100644 index 000000000..2d48f077a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); From 08067aa7a7349a2e01f51631795f9ec1fd4d8d09 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 12:20:42 +0300 Subject: [PATCH 23/33] New iq2_kt: NEON GEMM/GEMV --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 186 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 186 insertions(+), 2 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 638cbb930..277aa45a6 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1452,6 +1452,51 @@ struct Trellis3 { } return result; } + inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const { + auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3); + int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; + int8x16x2_t i8; + for (int i = 0; i < 2; ++i) { + i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+0]+v0)); + i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+1]+v0)); + i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2)); + } + return result; + } + inline int8x16x4_t next64(const uint32_t * val) const { + auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3); + int8x16x4_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126)}; + for (int i = 0; i < 2; ++i) { + auto i8_1 = next8(val[4*i+0], val[4*i+1]); + int8x16x2_t i8_2{vmlaq_u32(vkb3, vka3, i8_1.val[0]), vmlaq_u32(vkb3, vka3, i8_1.val[1])}; + i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1_1 = 
vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1])); + auto s1_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1])); + i8_1 = next8(val[4*i+2], val[4*i+3]); + i8_2.val[0] = vmlaq_u32(vkb3, vka3, i8_1.val[0]); + i8_2.val[1] = vmlaq_u32(vkb3, vka3, i8_1.val[1]); + i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s2_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1])); + auto s2_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1])); + result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1)); + result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2)); + } + return result; + } static uint8x16_t load_shuffle() { static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; return vld1q_u8(k_shuffle); @@ -1612,6 +1657,136 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& } } +void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = vld1q_s8(iq4k_values); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq2_kt * x8[8]; + float dkt[8]; + float ls[8], ls_all[64]; + uint32_t idx[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0] * 1.05f; + x8[k] = (const block_iq2_kt *)(dptr + 1); + } + auto vd = vld1q_f32_x2(dkt); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto u32 = *(const uint32_t *)x8[k][i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16)))); + vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0)); + auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4)); + vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1)); + vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + } + vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx)); + } + } + y += 8; // = QK_K/32; + } + } +} + +template +void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto values = vld1q_s8(iq4k_values); + + constexpr int k_acc = nrc_y == 1 ? 
2 : nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + int8x16x2_t xv[8]; + int32x4x4_t dot; + + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]*1.05f); + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto u32 = *(const uint32_t *)x[i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16)))); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); + const uint16_t * ql = (const uint16_t *)x[i].ql; + for (int k = 0; k < 8; ++k) xv[k] = trellis.next32(ql + 4*k, 4096); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); + } + } + } + + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } + } +} + } bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { @@ -1628,6 +1803,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F16 : type; + case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? 
GGML_TYPE_Q8_0_R8 : type; default: break; From d075a1c75b52ee9d45e0c01278757ad7fb64d9fb Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 12:58:27 +0300 Subject: [PATCH 24/33] New iq2_kt: slightly faster NEON GEMM --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 277aa45a6..8b8cae14d 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1759,6 +1759,30 @@ void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16)))); scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16)))); const uint16_t * ql = (const uint16_t *)x[i].ql; + if constexpr (nrc_y == 1) { + const block_q8_0_x4& ybl = y[0][2*i+0]; + const block_q8_0_x4& ybh = y[0][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + int32x4x4_t suml = {}; + int32x4x4_t sumh = {}; + for (int ib = 0; ib < 4; ++ib) { + auto xl = trellis.next32(ql + 4*ib + 0, 4096); + auto xh = trellis.next32(ql + 4*ib + 16, 4096); + auto yl = vld1q_s8_x2(ybl.qs + 32*ib); + auto yh = vld1q_s8_x2(ybh.qs + 32*ib); + suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]); + sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]); + } + auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]); + auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]); + auto sl = vpaddq_s32(sl1, sl2); + auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]); + auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]); + auto sh = vpaddq_s32(sh1, sh2); + accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl)); + accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh)); + } else { for (int k = 0; k < 8; ++k) xv[k] = trellis.next32(ql + 4*k, 4096); for (int iy = 0; iy < nrc_y; ++iy) { const block_q8_0_x4& ybl = y[iy][2*i+0]; @@ -1775,6 +1799,7 @@ void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); } } + } } if constexpr (nrc_y == 1) { From f2be982fd8d0f7a2176c0437640ade4225ad5932 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 15:04:23 +0300 Subject: [PATCH 25/33] New iq2_kt: Metal - very slow. It seems Apple Silicon cannot quickly add 4 8-bit ints. Or I don't know how to do it - but I didn't find anything in the Metal Shading Language Specification. So, performance is quite a bit worse than the original trellis. 
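For reference, a minimal scalar sketch (not part of this patch; the helper name is made up here, and the constants are taken from the CUDA/AVX2 paths used up to this point in the series) of what one step of the integer trellis computes. The 4-byte horizontal add in the middle is the operation that dp4a does in a single instruction on CUDA, and that the Metal kernel in this patch has to spell out byte by byte:

    #include <cstdint>

    // One trellis step: LCG update, keep 6 bits per byte, sum the 4 bytes, recentre.
    static inline int8_t trellis_next_i8(uint32_t & state) {
        constexpr uint32_t ka = 89226354;   // LCG multiplier (pre @louiehelm multiplier change)
        constexpr uint32_t kb = 64248484;   // LCG increment
        state = ka*state + kb;
        uint32_t s = state & 0x3f3f3f3f;    // four bytes, each in 0..63
        int sum = (s & 0xff) + ((s >> 8) & 0xff) + ((s >> 16) & 0xff) + ((s >> 24) & 0xff);
        return (int8_t)(sum - 126);         // values in [-126, 126], roughly centred at 0
    }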
--- ggml/src/ggml-metal.metal | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index e1d474040..c3c4f0bb4 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6613,6 +6613,18 @@ struct Trellis3 { for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; return result; } + template + static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) { + thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3}; + uint32_t aux32[2]; + thread const int8_t * a8 = (thread const int8_t *)aux32; + for (int i = 0; i < 4; ++i) { + aux32[0] = aux[i] & kmask; + aux32[1] = (ka3*aux[i] + kb3) & kmask; + v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3]; + v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7]; + } + } }; struct Trellis { @@ -6710,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl( float drow[N_DST]; for (int row = 0; row < N_DST; ++row) { device const float * dptr = (device const float *)(cx + row*row_size); - drow[row] = dptr[0] * 31.75f * 1.05f; + drow[row] = dptr[0] * 1.05f; } device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float)); @@ -6725,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl( const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf]; - Trellis::gen8(q2[2*it+0]+4096, v1, v2); + Trellis3::gen8(q2[2*it+0]+4096, v1, v2); auto sum = v1*y4[0] + v2*y4[1]; - Trellis::gen8(q2[2*it+1]+4096, v1, v2); + Trellis3::gen8(q2[2*it+1]+4096, v1, v2); sum += v1*y4[2] + v2*y4[3]; sum *= ls; @@ -8561,19 +8573,18 @@ template void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 int ib32 = il/2; - half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h; + half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h; device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2); - half4 v1, v2; + char4 v1, v2; for (int i = 0; i < 2; ++i) { - Trellis::gen8(q2[i]+4096, v1, v2); - v1 *= scale; v2 *= scale; + Trellis3::gen8(q2[i]+4096, v1, v2); if constexpr (is_same_v) { - reg[2*i+0] = v1; - reg[2*i+1] = v2; + reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]}; + reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]}; } else { - reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]}; - reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]}; + reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]}; + reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]}; } } } From 57e882fd84aadaa0540e18a58c01a2f14c6f680c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 15:18:11 +0300 Subject: [PATCH 26/33] Add missing break --- ggml/src/ggml-cuda/mmvq.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 19a72afa5..76126d768 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -526,6 +526,7 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm break; case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case 
GGML_TYPE_IQ4_KT: mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; From de4e6c797f12874b8221a3776b3601bf6d7420cd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 13 Jun 2025 19:38:11 +0300 Subject: [PATCH 27/33] Trying @louiehelm's multiplier --- ggml/src/ggml-cuda/convert.cu | 5 ++--- ggml/src/ggml-cuda/iqk_mmvq.cu | 12 +++++------- ggml/src/ggml-cuda/mmq.cuh | 14 ++++++-------- ggml/src/iqk/iqk_quantize.cpp | 7 ++++--- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index cb1408e6e..7fddcc895 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -341,9 +341,8 @@ inline __device__ int nearest_int(float fval) { } int __device__ __forceinline__ trellis_next_int(uint32_t& val) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; - val = ka*val + kb; + constexpr uint32_t ka = 0xCBAC1FED; + val = ka*val; return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index c026ff07b..bec6a739d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -436,8 +436,7 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; + constexpr uint32_t ka = 0xCBAC1FED; constexpr uint32_t km = 0x3f3f3f3f; float scale = *(const float *)vbq; @@ -461,7 +460,7 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0; int v4 = 0; for (int k = 0; k < 4; ++k) { - val = ka*val + kb; + val *= ka; //int s = val & km; //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126); v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; @@ -474,8 +473,7 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( __device__ __forceinline__ void vec_dot_iq2_kt_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; + constexpr uint32_t ka = 0xCBAC1FED; constexpr uint32_t km = 0x3f3f3f3f; float scale = *(const float *)vbq; @@ -492,13 +490,13 @@ __device__ __forceinline__ void vec_dot_iq2_kt_q8_1( uint32_t val = ql[4*ib32+j] + 4096; int v4 = 0; for (int k = 0; k < 4; ++k) { - val = ka*val + kb; + val *= ka; v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; } sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); v4 = 0; for (int k = 0; k < 4; ++k) { - val = ka*val + kb; + val *= ka; v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; } sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index e2c76a854..a7a6f5e5b 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2799,8 +2799,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void load_tiles_iq4_kt( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; + constexpr uint32_t ka = 
0xCBAC1FED; constexpr uint32_t km = 0x3f3f3f3f; #ifdef INT8_MMA_AVAILABLE @@ -2835,8 +2834,8 @@ template static __device__ __forceinlin uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); int2 v = {0, 0}; for (int k = 0; k < 4; ++k) { - val1 = ka*val1 + kb; - val2 = ka*val2 + kb; + val1 *= ka; + val2 *= ka; v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; } @@ -2872,8 +2871,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void load_tiles_iq2_kt( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; + constexpr uint32_t ka = 0xCBAC1FED; constexpr uint32_t km = 0x3f3f3f3f; #ifdef INT8_MMA_AVAILABLE @@ -2903,11 +2901,11 @@ template static __device__ __forceinlin uint32_t val = ql[4*ib32+j] + 4096; int2 v = {0, 0}; for (int k = 0; k < 4; ++k) { - val = ka*val + kb; + val *= ka; v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; } for (int k = 0; k < 4; ++k) { - val = ka*val + kb; + val *= ka; v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; } #ifdef INT8_MMA_AVAILABLE diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index b9e9d7752..65cd1c3eb 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7419,18 +7419,19 @@ class QuantizerIQKT { inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const; static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; uint32_t x = i + offset; if constexpr (is_int) { + constexpr uint32_t ka = 0xCBAC1FED; uint32_t s; auto i8 = (const int8_t *)&s; for (int k = 0; k < kGroupSize; ++k) { - x = ka*x + kb; + x = ka*x; s = x & 0x3f3f3f3f; result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); } } else { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; constexpr uint32_t kmask = 0x8fff8fff; constexpr uint32_t km32 = 0x3b603b60; for (int k = 0; k < kGroupSize; ++k) { From 6d38e43f1da728f22015535cd4a18e9da1f253fb Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 14 Jun 2025 06:19:18 +0300 Subject: [PATCH 28/33] CPU --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 34 +++++++++--------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 8b8cae14d..fc954b545 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -100,33 +100,24 @@ struct Trellis2 { template struct Trellis3 { - constexpr static uint32_t ka = 89226354; - constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka = 0xCBAC1FED; constexpr static uint32_t ka1 = ka*ka; - constexpr static uint32_t kb1 = kb*ka+kb; constexpr static uint32_t ka2 = ka1*ka; - constexpr static uint32_t kb2 = kb1*ka+kb; constexpr static uint32_t ka3 = ka2*ka; - constexpr static uint32_t kb3 = kb2*ka+kb; constexpr static uint32_t ka4 = ka3*ka; - constexpr static uint32_t kb4 = kb3*ka+kb; constexpr static uint32_t ka5 = ka4*ka; - constexpr static uint32_t kb5 = kb4*ka+kb; constexpr static uint32_t ka6 = ka5*ka; - constexpr static uint32_t kb6 = kb5*ka+kb; constexpr static uint32_t ka7 = ka6*ka; - constexpr 
static uint32_t kb7 = kb6*ka+kb; const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); - const __m256i mkb = is_8 ? _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); inline __m256i next8(uint32_t val1, uint32_t val2) const { __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); - return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_mullo_epi32(mval, mka); } inline __m256i next8(uint32_t val) const { __m256i mval = _mm256_set1_epi32(val); - return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_mullo_epi32(mval, mka); } inline __m256 gen8(uint32_t val1, uint32_t val2) const { auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); @@ -189,11 +180,11 @@ struct Trellis3 { template inline void next64(const uint32_t * val, __m256i * result) const { const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126); - auto vka3 = _mm256_set1_epi32(ka3), vkb3 = _mm256_set1_epi32(kb3); + auto vka3 = _mm256_set1_epi32(ka3); __m256i aux[8]; for (int i = 0; i < 4; ++i) { auto i8_1 = next8(val[2*i+0], val[2*i+1]); - auto i8_2 = _mm256_add_epi32(_mm256_mullo_epi32(i8_1, vka3), vkb3); + auto i8_2 = _mm256_mullo_epi32(i8_1, vka3); i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f)); i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f)); #ifdef HAVE_FANCY_SIMD @@ -1419,22 +1410,17 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } struct Trellis3 { - constexpr static uint32_t ka = 89226354; - constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka = ;0xCBAC1FED; constexpr static uint32_t ka1 = ka*ka; - constexpr static uint32_t kb1 = kb*ka+kb; constexpr static uint32_t ka2 = ka1*ka; - constexpr static uint32_t kb2 = kb1*ka+kb; constexpr static uint32_t ka3 = ka2*ka; - constexpr static uint32_t kb3 = kb2*ka+kb; const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3}; - const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3}; const uint8x16_t shuffle = load_shuffle(); inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const { uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)}; - result.val[0] = vmlaq_u32(mkb, mka, result.val[0]); - result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); + result.val[0] = vmulq_u32(mka, result.val[0]); + result.val[1] = vmulq_u32(mka, result.val[1]); return result; } inline int8x16x2_t next32(const uint32_t * val) const { @@ -1457,12 +1443,12 @@ struct Trellis3 { int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; int8x16x2_t i8; for (int i = 0; i < 2; ++i) { - i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+0]+v0)); + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+0]+v0)); i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); - i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+1]+v0)); + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+1]+v0)); i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); From 32ff1f956f391483327b2eacae9adbf9dcc73b80 Mon Sep 17 00:00:00 2001 
From: Iwan Kawrakow Date: Mon, 16 Jun 2025 16:57:16 +0300 Subject: [PATCH 29/33] iq3_kt: use integer trellis + CUDA dequantize and MMVQ --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++----- ggml/src/ggml-cuda/convert.cu | 4 +-- ggml/src/ggml-cuda/iqk_mmvq.cu | 50 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 9 ++++-- ggml/src/ggml-cuda/mmvq.cu | 12 +++++--- ggml/src/iqk/iqk_quantize.cpp | 10 +++++-- 6 files changed, 88 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 291378f42..a0cdab287 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -578,6 +578,20 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; @@ -648,13 +662,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -template<> -struct ggml_cuda_type_traits { - static constexpr int qk = QK_K; - static constexpr int qr = QR4_XS; - static constexpr int qi = QI4_XS; -}; - ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 7fddcc895..b40079a3a 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -394,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f; + const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f; uint8_t mask = 1 << (ib/4); for (int j = 0; j < 8; ++j) { - y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); + y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? 
-1.f : 1.f); } } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index bec6a739d..c19215de3 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -504,6 +504,48 @@ __device__ __forceinline__ void vec_dot_iq2_kt_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq3_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + const float dl = scale * ls * 1.015f; + auto ql = (const uint16_t *)bq3->ql; + uint32_t mask = 0x01010101 << ib32; + const uint32_t * qh = (const uint32_t *)bq3->qh; + int sumi = 0; + for (int j = 0; j < 4; ++j) { + uint32_t val = ql[4*ib32+j] + 4096; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi); + v4 = 0; + for (int k = 0; k < 4; ++k) { + val *= ka; + int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)); + v4 |= q << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v4 = __vsub4(v4 ^ signs, signs); + sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1304,6 +1346,14 @@ void mul_mat_vec_iq2_kt_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index a77bef543..e7c6e1d29 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -101,12 +101,17 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); -void mul_mat_vec_iq4_kt_q8_1_cuda( +void mul_mat_vec_iq2_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); -void mul_mat_vec_iq2_kt_q8_1_cuda( +void 
mul_mat_vec_iq3_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 76126d768..6412be30b 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -527,12 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; - case GGML_TYPE_IQ4_KT: - mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; case GGML_TYPE_IQ2_KT: mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; + case GGML_TYPE_IQ3_KT: + mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_KT: + mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -693,8 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: - case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: return true; default: return false; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 65cd1c3eb..0384e49ab 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7427,7 +7427,11 @@ class QuantizerIQKT { for (int k = 0; k < kGroupSize; ++k) { x = ka*x; s = x & 0x3f3f3f3f; - result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + if constexpr (is_abs) { + result[k] = scale*std::abs(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } else { + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } } } else { constexpr uint32_t ka = 89226354; @@ -8289,7 +8293,7 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace { -using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>; +using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true, true>; const QuantizerIQ3KT& iq3kt_quantizer() { static std::mutex mutex; std::lock_guard lock(mutex); @@ -8500,7 +8504,7 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if 
(iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ3KT; constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; From f6fa5652a322e75a19bd4784b1919af1546ad10b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 16 Jun 2025 17:21:17 +0300 Subject: [PATCH 30/33] iq3_kt: MMQ --- ggml/src/ggml-cuda/mmq.cu | 6 +- ggml/src/ggml-cuda/mmq.cuh | 88 +++++++++++++++++++ .../template-instances/mmq-instance-iq3_kt.cu | 5 ++ 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 67fa335a2..9103e3f1b 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -106,6 +106,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ2_KT: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ3_KT: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case(ctx, args, stream); break; @@ -178,8 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: - case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a7a6f5e5b..c6e6a3651 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -94,6 +94,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: @@ -205,6 +206,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } @@ -255,6 +257,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } @@ -2939,6 +2942,83 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq3_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto 
ql = (const uint16_t *)bxi->ql; + const auto qh = (const uint32_t *)bxi->qh; + uint32_t mask = 0x01010101 << ib32; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + auto signs = __vcmpne4(qh[2*j+0] & mask, 0); + v.x = __vsub4(v.x ^ signs, signs); + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v.y = __vsub4(v.y ^ signs, signs); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.01f; + const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3545,6 +3625,13 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; @@ -4008,6 +4095,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu new file mode 100644 index 000000000..978bc6ca0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); From 6153d0e7946e63226dcbdc8f2eb9d9ec12e54cba Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 16 Jun 2025 18:21:33 +0300 Subject: [PATCH 31/33] iq3_kt: AVX2 GEMM --- ggml/src/ggml.c | 16 ++--- ggml/src/iqk/iqk_gemm_ktquants.cpp | 108 +++++++++++++++++++++++++---- 2 files changed, 103 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index edf25e35a..cc056f89a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1601,11 +1601,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif -//#ifdef __ARM_NEON -// .vec_dot_type = GGML_TYPE_F16, -//#else -// .vec_dot_type = GGML_TYPE_F32, -//#endif .nrows = 1, .row_meta_size = 4, }, @@ -1618,11 +1613,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq3_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref, .vec_dot = vec_dot_iq3_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif +//#ifdef __ARM_NEON +// .vec_dot_type = GGML_TYPE_F16, +//#else +// .vec_dot_type = GGML_TYPE_F32, +//#endif .nrows = 1, .row_meta_size = 4, }, diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index fc954b545..630c2e27d 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -98,7 +98,7 @@ struct Trellis2 { }; -template +template struct Trellis3 { constexpr static uint32_t ka = 0xCBAC1FED; constexpr static uint32_t ka1 = ka*ka; @@ -127,7 +127,11 @@ struct Trellis3 { auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); #endif - return _mm256_cvtepi32_ps(i8); + if constexpr (is_abs) { + return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8)); + } else { + return _mm256_cvtepi32_ps(i8); + } } inline __m256 gen8(uint32_t val) const { auto v8 = _mm256_and_si256(next8(val), _mm256_set1_epi32(0x3f3f3f3f)); @@ -137,11 +141,14 @@ struct Trellis3 { auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); #endif - return _mm256_cvtepi32_ps(i8); + if constexpr (is_abs) { + return _mm256_cvtepi32_ps(_mm256_sign_epi32(i8, i8)); + } else { + return _mm256_cvtepi32_ps(i8); + } } - template inline __m256i next32(const uint32_t * val) const { - const __m256i offset = is_unsigned ? 
_mm256_setzero_si256() : _mm256_set1_epi32(-126); + const __m256i offset = _mm256_set1_epi32(-126); __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); @@ -156,11 +163,15 @@ struct Trellis3 { aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - return _mm256_permutevar8x32_epi32(aux[0], shuffle); + if constexpr (is_abs) { + auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle); + return _mm256_sign_epi8(result, result); + } else { + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } } - template inline __m256i next32(const uint16_t * val, uint32_t v0) const { - const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126); + const __m256i offset = _mm256_set1_epi32(-126); __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(v0 + val[i]), _mm256_set1_epi32(0x3f3f3f3f)); @@ -175,11 +186,15 @@ struct Trellis3 { aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - return _mm256_permutevar8x32_epi32(aux[0], shuffle); + if constexpr (is_abs) { + auto result = _mm256_permutevar8x32_epi32(aux[0], shuffle); + return _mm256_sign_epi8(result, result); + } else { + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } } - template inline void next64(const uint32_t * val, __m256i * result) const { - const __m256i offset = is_unsigned ? 
_mm256_setzero_si256() : _mm256_set1_epi32(-126); + const __m256i offset = _mm256_set1_epi32(-126); auto vka3 = _mm256_set1_epi32(ka3); __m256i aux[8]; for (int i = 0; i < 4; ++i) { @@ -203,6 +218,9 @@ struct Trellis3 { aux[4*k+0] = _mm256_packs_epi16(aux[4*k+0], aux[4*k+2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 result[k] = _mm256_permutevar8x32_epi32(aux[4*k+0], shuffle); + if constexpr (is_abs) { + result[k] = _mm256_sign_epi8(result[k], result[k]); + } } } }; @@ -449,6 +467,70 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& } } +void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + + Trellis3 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq3_kt * x8[8]; + float dkt[8]; + float ls[8]; + float ls_all[64]; + uint32_t idx[8]; + uint32_t sign_bits[16]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq3_kt *)(dptr + 1); + } + auto vd = _mm256_mul_ps(_mm256_set1_ps(1.01f), _mm256_loadu_ps(dkt)); + + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + auto s8 = _mm_set1_epi32(*(const uint32_t *)x8[k][i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + auto s32 = _mm256_cvtepi8_epi32(s8); + _mm256_storeu_ps(ls_all + 8*k, _mm256_cvtepi32_ps(s32)); + } + auto mask = _mm256_set1_epi8(1); + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib]; + auto scales = _mm256_mul_ps(vd, _mm256_loadu_ps(ls)); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 8; ++k) { + const uint16_t * ql = (const uint16_t *)x8[k][i].ql; + idx[k] = ql[4*ib+j] + 4096; + auto qh = (const uint32_t *)x8[k][i].qh; + sign_bits[k+0] = qh[2*j+0]; + sign_bits[k+8] = qh[2*j+1]; + } + __m256i packed[2]; + trellis.next64(idx, packed); + auto signs1 = _mm256_loadu_si256((const __m256i *)sign_bits+0); + auto signs2 = _mm256_loadu_si256((const __m256i *)sign_bits+1); + signs1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs1, mask), mask), _mm256_set1_epi8(1)); + signs2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs2, mask), mask), _mm256_set1_epi8(1)); + packed[0] = _mm256_sign_epi8(packed[0], signs1); + packed[1] = _mm256_sign_epi8(packed[1], signs2); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, packed[0]); + _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, packed[1]); + } + mask = _mm256_slli_epi16(mask, 1); + } + y += 8; // = QK_K/32; + } + } +} + inline __m256 abs_ps(__m256 vals) { // Clear sign-bit of all the 32-bit floats in vals __m256 sign_bit = _mm256_set1_ps(-0.0f); @@ -887,10 +969,10 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array Date: Mon, 16 Jun 2025 18:41:09 +0300 Subject: [PATCH 32/33] iq3_kt: AVX2 GEMV --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 94 ++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 630c2e27d..2ddfbe86f 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -531,6 +531,92 @@ void 

Date: Mon, 16 Jun 2025 18:41:09 +0300
Subject: [PATCH 32/33] iq3_kt: AVX2 GEMV

---
 ggml/src/iqk/iqk_gemm_ktquants.cpp | 94 ++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 630c2e27d..2ddfbe86f 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -531,6 +531,92 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
     }
 }
 
+template <int nrc_y>
+void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%QK_K == 0);
+    const int nb = n/QK_K;
+
+    Trellis3 trellis;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    constexpr int k_acc = nrc_y;
+
+    __m256 accd[k_acc];
+    const block_q8_2_x4 * y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const block_q8_2_x4 *)info.src1_row(iy);
+    }
+
+    __m256i xv[4], sv[4], dot[4];
+    __m256 scales[2];
+
+    auto sum_4 = [&dot] () {
+        // dot[k] has 8 values from block k
+        // 0 1 0 1 0 1 0 1
+        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1]));
+        // 2 3 2 3 2 3 2 3
+        dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3]));
+        // 0 1 2 3 0 1 2 3
+        dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2]));
+        return _mm256_cvtepi32_ps(dot[0]);
+    };
+
+    auto compute_dot = [&dot, &xv, &sv] (const int8_t * y) {
+        for (int k = 0; k < 4; ++k) {
+            auto yv = _mm256_loadu_si256((const __m256i *)y + k);
+#ifdef HAVE_FANCY_SIMD
+            //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv);
+            dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], _mm256_sign_epi8(yv, sv[k]));
+#else
+            auto p = _mm256_maddubs_epi16(xv[k], _mm256_sign_epi8(yv, sv[k]));
+            dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1));
+#endif
+        }
+    };
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        const float * dptr = (const float *)((const char*)vx + ix*bx);
+        auto d = _mm256_set1_ps(dptr[0] * 1.01f);
+        const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1);
+
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        for (int i = 0; i < nb; ++i) {
+            auto ql = (const uint16_t *)x[i].ql;
+            auto sign_bits = _mm256_loadu_si256((const __m256i *)x[i].qh);
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(s32));
+            auto scales_l = _mm256_castps256_ps128(all_scales);
+            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+            scales[0] = _mm256_set_m128(scales_l, scales_l);
+            scales[1] = _mm256_set_m128(scales_h, scales_h);
+            auto mask = _mm256_set1_epi8(1);
+            for (int i128 = 0; i128 < 2; ++i128) {
+                for (int k = 0; k < 4; ++k) {
+                    xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
+                    sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
+                    mask = _mm256_slli_epi16(mask, 1);
+                }
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    const block_q8_2_x4& yb = y[iy][2*i+i128];
+                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
+                    dy = _mm256_mul_ps(scales[i128], dy);
+                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
+                    compute_dot(yb.qs);
+                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
+                }
+            }
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+    }
+}
+
 inline __m256 abs_ps(__m256 vals) {
     // Clear sign-bit of all the 32-bit floats in vals
     __m256 sign_bit = _mm256_set1_ps(-0.0f);
@@ -947,6 +1033,14 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array

Date: Mon, 16 Jun 2025 19:18:16 +0300
Subject: [PATCH 33/33] The trellis quants now need super-blocks of 256, so we need a
 check

---
 src/llama.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index dfd533373..af8ef9be2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18627,6 +18627,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) {
         new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 ||
         new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 ||
         new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4||
+        new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT ||
         new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4) {
         if (nx % QK_K != 0) {
             LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
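
For reference, the iq3_kt kernels added above split each weight into a trellis-generated
magnitude (Trellis3 next32/next64 with is_abs, so the values are already non-negative), a
per-value sign bit taken from x[i].qh, and a 4-bit per-32-block scale from x[i].scales on top
of the per-row float scale. Below is a minimal scalar sketch of that reconstruction; the
function name and parameters are illustrative only (not code from the patches), and the
trellis magnitude is taken as an input rather than recomputed here.

    #include <cstdint>

    // One reconstructed iq3_kt weight, given the already-computed trellis magnitude.
    static inline float iq3_kt_weight_sketch(float   d_row,        // per-row scale, dptr[0] * 1.01f
                                             int     ls,           // 4-bit scale of the 32-wide block
                                             int     trellis_abs,  // |trellis value| for this position
                                             uint8_t qh_byte,      // sign byte from x[i].qh for this position
                                             int     ib) {         // which 32-wide block (0..7)
        // Bit ib of the qh byte selects the sign; the AVX2 kernels do the same thing with
        // _mm256_sign_epi8 and a byte mask that is shifted left once per 32-value block.
        int sign = ((qh_byte >> ib) & 1) ? -1 : 1;
        return d_row * ls * sign * trellis_abs;
    }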