From eb61f498d177f7acdc89008f54884cd3e0580131 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 12:30:37 +0300 Subject: [PATCH 01/17] New iq4_kt trellis The new trellis generates int8_t values via sum_as_uint8_t[(ka * idx + kb) & 0x3f33f3f3f] - 126. CUDA dequantize works. AVX2 case Ny > 32 works, and we get 273 t/s for L3-8B. PPL is on par or even slightly lower than original QTIP trellis. --- ggml/src/ggml-cuda/convert.cu | 12 +-- ggml/src/ggml.c | 11 +- ggml/src/iqk/iqk_gemm_ktquants.cpp | 157 +++++++++++++++++++++++++++-- ggml/src/iqk/iqk_mul_mat.cpp | 6 +- ggml/src/iqk/iqk_quantize.cpp | 29 ++++-- 5 files changed, 181 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b2f77e094..cba8d66b9 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -343,13 +343,9 @@ inline __device__ int nearest_int(float fval) { float __device__ __forceinline__ trellis_next(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; - uint32_t s; - const half * h = (const half *)&s; val = ka*val + kb; - s = (val & kmask) ^ km32; - return (float)(h[0]+h[1]); + //return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, 0x82828282); + return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); } template @@ -367,7 +363,7 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; for (int j = 0; j < 8; ++j) { y[j] = dl * trellis_next(idx); } @@ -401,7 +397,7 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t ii = blockIdx.x; int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); - float scale = dptr[0] * 31.75f * 1.01f; + float scale = dptr[0] * 1.00f; float row_av = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); const int64_t i = ii - (row*n_per_row)/QK_K; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5bb75d322..b1fa39afa 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1617,11 +1617,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, .vec_dot = vec_dot_iq4_kt_q8_k, -#ifdef __ARM_NEON - .vec_dot_type = GGML_TYPE_F16, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else - .vec_dot_type = GGML_TYPE_F32, + .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif +//#ifdef __ARM_NEON +// .vec_dot_type = GGML_TYPE_F16, +//#else +// .vec_dot_type = GGML_TYPE_F32, +//#endif .nrows = 1, .row_meta_size = 8, }, diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index bc7bcf8b6..0a8d2d037 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -97,6 +97,43 @@ struct Trellis2 { } }; + +struct Trellis3 { + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t 
ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); + const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); + const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + + inline __m256i next8(uint32_t val1, uint32_t val2) const { + __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); + return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + } + inline __m256 gen8(uint32_t val1, uint32_t val2) const { + auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); + auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); + return _mm256_cvtepi32_ps(i8); + } + inline __m256i next32(const uint32_t * val) const { + __m256i aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); + aux[i] = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), i8); + } + aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + aux[0] = _mm256_packs_epi16(aux[0], aux[2]); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + return _mm256_permutevar8x32_epi32(aux[0], shuffle); + } +}; + void iqk_dequantize_iq2_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -315,19 +352,121 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +// Q8_0 repacking: +// for (int ib = 0; ib < nblock; ++ib) { +// for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; +// for (int l = 0; l < 4; ++l) { +// for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { +// y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; +// y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; +// as uint32_t +// y[ib].qs[8*l+k+ 0] = x8[k][ib].qs[l+ 0]; +// y[ib].qs[8*l+k+32] = x8[k][ib].qs[l+16]; +// } +// } +// } + +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[16]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); + _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + //for (int k = 0; k < 8; ++k) { + // auto shb = x8[k][i].qs; + // const uint8_t * ql = (const uint8_t *)(shb + 8); + // const uint8_t * qh = ql + kNumGroups; + // for (int ib = 0; ib < 4; ++ib) { + // uint32_t offset1 = ((shb[ib+0] & 1) << 15) + 4096; + // uint32_t offset2 = ((shb[ib+4] & 1) << 15) + 4096; + // for 
(int j = 0; j < 4; ++j) { + // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // idx[64*ib + 16*j + k ] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + // idx[64*ib + 16*j + k + 8] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + // idx[64*ib + 16*j + k + 256] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + // idx[64*ib + 16*j + k + 264] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + // //uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + // //uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + // //uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + // //uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + // //auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); + // //auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); + // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); + // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); + // } + // } + //} + //for (int j = 0; j < 64; ++j) { + // _mm256_storeu_si256((__m256i *)y[j/8].qs+(j%8), trellis.next32(idx+8*j)); + //} + //int shift1 = 8 - 4*(ib/4); + //for (int j = 0; j < 4; ++j) { + // for (int k = 0; k < 8; ++k) { + // const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + // const uint8_t * qh = ql + kNumGroups; + // const uint32_t sh = x8[k][i].qs[ib] >> (8 + 6*j); + // idx[k+0] = ql[8*ib+2*j+0] + ((qh[8*(ib%4)+2*j+0] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + // idx[k+8] = ql[8*ib+2*j+1] + ((qh[8*(ib%4)+2*j+1] << shift1) & 0xf00) + ((sh & 56) << 9) + idx0[k]; + // } + // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, trellis.next32(idx+0)); + // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, trellis.next32(idx+8)); + //} + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + _mm256_storeu_si256((__m256i *)y[ib].qs+j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + + } +} + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = _mm256_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -349,8 +488,8 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = 
_mm256_fmadd_ps(scale1, trellis_gen8(trellis.next8(val1, val3)), dav); - auto x_val2 = _mm256_fmadd_ps(scale2, trellis_gen8(trellis.next8(val2, val4)), dav); + auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); + auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); _mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); @@ -370,7 +509,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf const int nb = n/QK_K; constexpr int kNumGroups = 64; - Trellis2 trellis; + Trellis3 trellis; union { __m256 vec; float val[8]; } s_helper; union { __m256i vec; uint32_t val[8]; } o_helper; @@ -389,7 +528,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto d = _mm256_set1_ps(dptr[0]); auto dav = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); @@ -413,8 +552,8 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3))); - auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4))); + auto x_val1 = _mm256_mul_ps(scale1, trellis.gen8(val1, val3)); + auto x_val2 = _mm256_mul_ps(scale2, trellis.gen8(val2, val4)); if constexpr (nrc_y == 1) { auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0); auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128); @@ -474,7 +613,7 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * switch (type) { case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; - case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float *)y, stride_y, nrc_x); break; + case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break; default: return false; } return true; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d04ad22a1..950cfb038 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -238,7 +238,7 @@ struct MulMat { switch (type) { case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; - case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? 
GGML_TYPE_Q8_0_R8 : type; default: break; } #else @@ -352,9 +352,9 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, first_x *= num_rows; nrc_x *= num_rows; - auto type_size = ggml_type_size(dequant_type); + //auto type_size = ggml_type_size(dequant_type); - size_t row_size_qx = ne00*type_size; + size_t row_size_qx = ggml_row_size(dequant_type, ne00); size_t row_size_qy = strideB; //printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index ee85344db..417cff72c 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7443,7 +7443,7 @@ class QuantizerIQKT { constexpr static int kNg = kBlockSize/kGroupSize; constexpr static int kNblock = kSuperBlockSize/kBlockSize; constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 - constexpr static float kScale = 31.75f; + constexpr static float kScale = 1.f; //31.75f; constexpr static bool kVerbose = false; QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); @@ -7456,15 +7456,19 @@ class QuantizerIQKT { static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; + //constexpr uint32_t kmask = 0x8fff8fff; + //constexpr uint32_t km32 = 0x3b603b60; uint32_t x = i + offset; + uint32_t s; + auto i8 = (const int8_t *)&s; for (int k = 0; k < kGroupSize; ++k) { x = ka*x + kb; - uint32_t s = (x & kmask) ^ km32; - float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); - if constexpr (is_abs) result[k] = scale*std::abs(val); - else result[k] = scale*val; + s = x & 0x3f3f3f3f; + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + //uint32_t s = (x & kmask) ^ km32; + //float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + //if constexpr (is_abs) result[k] = scale*std::abs(val); + //else result[k] = scale*val; } } @@ -8244,7 +8248,7 @@ size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) { assert(k % QuantizerIQ2KT::kSuperBlockSize == 0); #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ2_KT, k, x, 0, y, 0, 1)) return; #endif const int nb = k / QuantizerIQ2KT::kSuperBlockSize; const float * dptr = (const float *)x; @@ -8595,7 +8599,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f row_av += x[j]; amax_row = std::max(amax_row, std::abs(x[j])); } - row_av /= n_per_row; + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + //row_av /= n_per_row; + row_av = 0; + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
dptr[1] = row_av; if (!amax_row) { dptr[0] = 0.f; @@ -8628,7 +8635,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f continue; } float best = 0; - float scale_0 = std::max(92.f, 127.f*amax/amax_row); + float scale_0 = std::max(90.f, 124.f*amax/amax_row); for (int itry = -kNtry; itry <= kNtry; ++itry) { quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx); @@ -8759,7 +8766,7 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { #ifdef __AVX2__ - if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; + //if (iqk_dequantize_ktquants(GGML_TYPE_IQ4_KT, k, x, 0, y, 0, 1)) return; #endif using Q = QuantizerIQ4KT; assert(k % Q::kSuperBlockSize == 0); From a2961344349cc846024ac4394935f45d0631f95c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 16:18:58 +0300 Subject: [PATCH 02/17] Something is not working with the AVX2 dot product --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 264 ++++++++++++++++++++++------- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 208 insertions(+), 58 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 0a8d2d037..5152f33dd 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -120,11 +120,13 @@ struct Trellis3 { auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); return _mm256_cvtepi32_ps(i8); } + template inline __m256i next32(const uint32_t * val) const { + const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126); __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); - aux[i] = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), i8); + aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); } aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -352,20 +354,6 @@ void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } -// Q8_0 repacking: -// for (int ib = 0; ib < nblock; ++ib) { -// for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; -// for (int l = 0; l < 4; ++l) { -// for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { -// y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; -// y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; -// as uint32_t -// y[ib].qs[8*l+k+ 0] = x8[k][ib].qs[l+ 0]; -// y[ib].qs[8*l+k+32] = x8[k][ib].qs[l+16]; -// } -// } -// } - void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { GGML_ASSERT(n%QK_K == 0); GGML_ASSERT(nrc_x%8 == 0); @@ -397,46 +385,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); _mm_storeu_si128((__m128i *)y[ib].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); - //for (int k = 0; k < 8; ++k) { - // auto shb = x8[k][i].qs; - // const uint8_t * ql = (const uint8_t *)(shb + 8); - // const uint8_t * qh = ql + kNumGroups; - // for (int ib = 0; ib < 4; ++ib) { - // uint32_t offset1 = ((shb[ib+0] & 1) << 15) + 4096; - 
// uint32_t offset2 = ((shb[ib+4] & 1) << 15) + 4096; - // for (int j = 0; j < 4; ++j) { - // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - // idx[64*ib + 16*j + k ] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - // idx[64*ib + 16*j + k + 8] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - // idx[64*ib + 16*j + k + 256] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - // idx[64*ib + 16*j + k + 264] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; - // //uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - // //uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - // //uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - // //uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; - // //auto x_val1 = _mm256_fmadd_ps(scale1, trellis.gen8(val1, val3), dav); - // //auto x_val2 = _mm256_fmadd_ps(scale2, trellis.gen8(val2, val4), dav); - // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j, x_val1); - // //_mm256_storeu_ps(y + i*QK_K + 32*ib + 8*j + QK_K/2, x_val2); - // } - // } - //} - //for (int j = 0; j < 64; ++j) { - // _mm256_storeu_si256((__m256i *)y[j/8].qs+(j%8), trellis.next32(idx+8*j)); - //} - //int shift1 = 8 - 4*(ib/4); - //for (int j = 0; j < 4; ++j) { - // for (int k = 0; k < 8; ++k) { - // const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); - // const uint8_t * qh = ql + kNumGroups; - // const uint32_t sh = x8[k][i].qs[ib] >> (8 + 6*j); - // idx[k+0] = ql[8*ib+2*j+0] + ((qh[8*(ib%4)+2*j+0] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; - // idx[k+8] = ql[8*ib+2*j+1] + ((qh[8*(ib%4)+2*j+1] << shift1) & 0xf00) + ((sh & 56) << 9) + idx0[k]; - // } - // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+0, trellis.next32(idx+0)); - // _mm256_storeu_si256((__m256i *)y[ib].qs+2*j+1, trellis.next32(idx+8)); - //} int shift1 = 8 - 4*(ib/4); for (int j = 0; j < 8; ++j) { for (int k = 0; k < 8; ++k) { @@ -454,6 +402,92 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } } +/* +template +void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + __m256i xv[8]; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[8]; + + union { float f; uint32_t u; } bf16_helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = _mm256_loadu_ps(dkt); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); + auto scales_m = _mm256_mul_ps(scales, 
_mm256_set1_ps(-126.f)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + xv[j] = trellis.next32(idx); + } + for (int iy = 0; iy < nrc_y; ++iy) { + const auto& yb = y[iy][2*i+ib/4]; + int i4 = ib%4; + auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0); + auto vy = MM256_SET_M128I(vy8, vy8); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50)); + sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff)); + vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1); + vy = MM256_SET_M128I(vy8, vy8); + sumi = _mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50)); + sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff)); + bf16_helper.u = yb.d[i4] << 16; + auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f)); + accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]); + bf16_helper.u = yb.d[i4+4] << 16; + accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, accd[iy]); + } + } +} +*/ + void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -503,6 +537,112 @@ void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t } } +template +void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { __m256i vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y; + + __m256 accd[k_acc]; + const block_q8_2_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); + } + + uint32_t values[64]; + __m256i xv[4], dot[4]; + __m256 scales[2]; + + auto sum_4 = [&dot] () { + // dot[k] has 8 values from block k + // 0 1 0 1 0 1 0 1 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[0], dot[1]), _mm256_unpackhi_epi32(dot[0], dot[1])); + // 2 3 2 3 2 3 2 3 + dot[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(dot[2], dot[3]), _mm256_unpackhi_epi32(dot[2], dot[3])); + // 0 1 2 3 0 1 2 3 + dot[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(dot[0], dot[2]), _mm256_unpackhi_epi64(dot[0], dot[2])); + return _mm256_cvtepi32_ps(dot[0]); + }; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = _mm256_loadu_si256((const __m256i *)y + k); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + } + }; + + auto m126 = _mm256_set1_ps(-126.f); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = _mm256_set1_ps(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; 
++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); + iscales = _mm256_sub_epi32(iscales, _mm256_set1_epi32(64)); + auto all_scales = _mm256_mul_ps(d, _mm256_cvtepi32_ps(iscales)); + auto scales_l = _mm256_castps256_ps128(all_scales); + auto scales_h = _mm256_extractf128_ps(all_scales, 1); + scales[0] = _mm256_set_m128(scales_l, scales_l); + scales[1] = _mm256_set_m128(scales_h, scales_h); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + for (int ib = 0; ib < 4; ++ib) { + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + } + } + // sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8) + // d4*d8*sum[x_i*y_i] - 126*d4*m8 + for (int i128 = 0; i128 < 2; ++i128) { + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + //auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16)); + //auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + //m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f)); + //for (int k = 0; k < 4; ++k) { + // xv[k] = trellis.next32(values + 32*i128 + 8*k); + // auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k); + // dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + //} + //accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]); + //accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_2_x4& yb = y[iy][2*i+i128]; + auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); + dy = _mm256_mul_ps(scales[i128], dy); + auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); + auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + compute_dot(yb.qs); + accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); + accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, hsum_float_8(accd[iy])); + } + } +} + template void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -585,11 +725,21 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { - if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) { + if (ne00%QK_K != 0) return false; + + func16 = nullptr; + + if (typeA == GGML_TYPE_IQ4_KT) { + if (typeB == GGML_TYPE_Q8_2_X4) { + 
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels); + return true; + } return false; } - func16 = nullptr; + if (ggml_type(typeB) != GGML_TYPE_F32) { + return false; + } switch (typeA) { case GGML_TYPE_IQ2_KT: diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 950cfb038..f2876b1b7 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -698,7 +698,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: - return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; + return iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16); case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: From 98f35bfaa2bf94426108e70534e4a45bf81f2edb Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 17:43:36 +0300 Subject: [PATCH 03/17] New iq4_kt: CUDA MMVQ --- ggml/src/ggml-cuda/common.cuh | 7 +++++ ggml/src/ggml-cuda/iqk_mmvq.cu | 46 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 5 ++++ ggml/src/ggml-cuda/mmvq.cu | 4 +++ 4 files changed, 62 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8f3d2a265..291378f42 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -648,6 +648,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + ////////////////////// struct ggml_cuda_device_info { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 747af5a7a..843e835ec 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -395,6 +395,44 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1( *result += dl * __low2float(bq8_1[ib32].ds) * sumi; } +__device__ __forceinline__ void vec_dot_iq4_kt_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + + float scale = *(const float *)vbq; + const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + 2*sizeof(float)) + kbx; + + // iqs is 0...28 + const int ib32 = iqs/4; // Why iqs/4 ? 
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs; + //const int8_t * q8 = bq8_1[ib32].qs; + const int ls = (bq4->qs[ib32] & 0xff) >> 1; + const float dl = scale * (ls - 64); + const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096; + auto ql = (const uint8_t *)(bq4->qs + 8); + auto qh = ql + 64; + ql += 8*ib32; + qh += 8*(ib32%4); + const int shift1 = 8 - 4*(ib32/4); + int sumi = 0; + for (int j = 0; j < 8; ++j) { + const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j); + uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0; + int v4 = 0; + for (int k = 0; k < 4; ++k) { + val = ka*val + kb; + //int s = val & km; + //sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126); + v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + sumi = ggml_cuda_dp4a(v4, q8[j], sumi); + } + *result += dl * __low2float(bq8_1[ib32].ds) * sumi; +} + #define VDR_IQ4_KSS_Q8_1_MMVQ 4 #define VDR_IQ4_KSS_Q8_1_MMQ 4 @@ -1171,6 +1209,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda( iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 1e4257e85..4739a7fe8 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -95,3 +95,8 @@ void mul_mat_vec_iq1_s_r4_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); + +void mul_mat_vec_iq4_kt_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index cc00d278f..1ed70470d 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -526,6 +526,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm break; case GGML_TYPE_IQ4_KSS: mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + case GGML_TYPE_IQ4_KT: + mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; break; case GGML_TYPE_IQ2_KS: mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, 
stream); @@ -683,6 +686,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ4_KT: return true; default: return false; From 9b4103ed548e0ce43b47e33312cd3c8ea86b6791 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:21:43 +0300 Subject: [PATCH 04/17] New iq4_kt: CUDA MMQ --- ggml/src/ggml-cuda/mmq.cu | 4 ++ ggml/src/ggml-cuda/mmq.cuh | 84 +++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_quantize.cpp | 11 ++++- 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index a13be11b1..21bf9003a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -100,6 +100,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_KS_R4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ4_KT: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ5_KS: mul_mat_q_case(ctx, args, stream); break; @@ -172,6 +175,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ4_KT: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 608de8f00..7d829c7e6 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -93,6 +93,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_KS_R4: case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ4_KT: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -202,6 +203,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } } @@ -250,6 +252,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } } @@ -2790,6 +2793,79 @@ template static __device__ __forceinlin } +template static __device__ __forceinline__ void load_tiles_iq4_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + 2*sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto shb = bxi->qs; + const auto ql = (const uint8_t *)(shb + 8); + const auto qh = ql + 64; + const uint32_t sh = shb[ib32] >> (8 + 6*j); + uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); + uint32_t val1 = offset + 
ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12); + uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val1 = ka*val1 + kb; + val2 = ka*val2 + kb; + v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; + v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 2) + kbx0; + const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -3382,6 +3458,13 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; @@ -3843,6 +3926,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 417cff72c..066002746 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7525,6 +7525,7 @@ QuantizerIQKT::QuantizerIQKT(int num_c set_values(i, data, kScale, offset); data += kGroupSize; } + if (num_clusters == 0) return; // Make 128 clusters. // Note: we get a slightly better result by using 64 clusters // at the expense of almost doubling the quantization time. 
@@ -8575,6 +8576,14 @@ const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { return *quantizer1; } +const QuantizerIQ4KT& iq4kt_dequantizer() { + static std::mutex mutex; + std::lock_guard lock(mutex); + static std::unique_ptr dequantizer; + if (!dequantizer) dequantizer = std::make_unique(0, 0, 4096); + return *dequantizer; +} + void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { constexpr float kSigmaScale = 2.0f; @@ -8776,7 +8785,7 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { const float d = dptr[0] * Q::kScale; const float row_av = dptr[1]; x = (const block_iq4_kt *)(dptr + 2); - auto& deq = iq4kt_quantizer(); + auto& deq = iq4kt_dequantizer(); for (int ibl = 0; ibl < nb; ++ibl) { auto shb = x[ibl].qs; auto ql = (const uint8_t *)(shb + Q::kNblock); From fb776ab7ba5f9ec539c7b3a2a72d47fdee7f290b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:35:10 +0300 Subject: [PATCH 05/17] For now have only iq4_kt use the new trellis --- ggml/src/iqk/iqk_quantize.cpp | 60 +++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 066002746..db447d84e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -7432,7 +7432,7 @@ __m256 hsum_float_4x8(__m256 * accm) { return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); } #endif -template +template class QuantizerIQKT { static_assert(group_size == 8 || group_size == 4); static_assert(block_size >= 8 && block_size%8 == 0); @@ -7443,7 +7443,7 @@ class QuantizerIQKT { constexpr static int kNg = kBlockSize/kGroupSize; constexpr static int kNblock = kSuperBlockSize/kBlockSize; constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 - constexpr static float kScale = 1.f; //31.75f; + constexpr static float kScale = is_int ? 
1.f : 31.75f; constexpr static bool kVerbose = false; QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); @@ -7456,19 +7456,25 @@ class QuantizerIQKT { static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; - //constexpr uint32_t kmask = 0x8fff8fff; - //constexpr uint32_t km32 = 0x3b603b60; uint32_t x = i + offset; - uint32_t s; - auto i8 = (const int8_t *)&s; - for (int k = 0; k < kGroupSize; ++k) { - x = ka*x + kb; - s = x & 0x3f3f3f3f; - result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); - //uint32_t s = (x & kmask) ^ km32; - //float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); - //if constexpr (is_abs) result[k] = scale*std::abs(val); - //else result[k] = scale*val; + if constexpr (is_int) { + uint32_t s; + auto i8 = (const int8_t *)&s; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + s = x & 0x3f3f3f3f; + result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f); + } + } else { + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + uint32_t s = (x & kmask) ^ km32; + float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + if constexpr (is_abs) result[k] = scale*std::abs(val); + else result[k] = scale*val; + } } } @@ -7517,8 +7523,8 @@ class QuantizerIQKT { float m_mid[4*kGroupSize]; }; -template -QuantizerIQKT::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { +template +QuantizerIQKT::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { m_values.resize(kNumVal*kGroupSize); float * data = m_values.data(); for (int i = 0; i < kNumVal; ++i) { @@ -7534,8 +7540,8 @@ QuantizerIQKT::QuantizerIQKT(int num_c m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values); } -template -std::pair QuantizerIQKT::find_best_scale( +template +std::pair QuantizerIQKT::find_best_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumq2 = 0; #ifdef __AVX2__ @@ -7567,8 +7573,8 @@ std::pair QuantizerIQKT: return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); } -template -float QuantizerIQKT::find_best_inverse_scale( +template +float QuantizerIQKT::find_best_inverse_scale( const float * xb, const float * weight, const int * best_idx) const { float sumqx = 0, sumx2 = 0; #ifdef __AVX2__ @@ -7600,8 +7606,8 @@ float QuantizerIQKT::find_best_inverse return sumx2 > 0 ? 
sumqx/sumx2 : 0.f; } -template -void QuantizerIQKT::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { +template +void QuantizerIQKT::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { if (!d) { std::memset(best_idx, 0, kNg*sizeof(int)); return; @@ -7779,8 +7785,8 @@ void QuantizerIQKT::find_best_match(fl #endif } -template -std::vector> QuantizerIQKT::finalize_clusters(int num_neighbours, +template +std::vector> QuantizerIQKT::finalize_clusters(int num_neighbours, const std::vector& values, const std::vector& clusters, std::vector>& c_values) { int ncluster = clusters.size()/kGroupSize; std::vector> p_in_cluster(ncluster); @@ -7866,8 +7872,8 @@ std::vector> QuantizerIQKT -std::vector QuantizerIQKT::cluster_points(const std::vector& points, int ncluster, int niter, float * mid) { +template +std::vector QuantizerIQKT::cluster_points(const std::vector& points, int ncluster, int niter, float * mid) { constexpr int ndim = kGroupSize; GGML_ASSERT(points.size() % ndim == 0); int npoint = points.size() / ndim; @@ -8561,7 +8567,7 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx namespace{ -using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>; +using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15, false, true>; const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { static std::mutex mutex; From 2b6acd58435e983836c2aefde5f2e4250eebd0da Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 18:51:02 +0300 Subject: [PATCH 06/17] Fix iq2_kt that got broken along the way --- ggml/src/ggml-cuda/convert.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index cba8d66b9..905595a76 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -340,14 +340,25 @@ inline __device__ int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -float __device__ __forceinline__ trellis_next(uint32_t& val) { +int __device__ __forceinline__ trellis_next_int(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; val = ka*val + kb; - //return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, 0x82828282); return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126); } +float __device__ __forceinline__ trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t s; + const half * h = (const half *)&s; + val = ka*val + kb; + s = (val & kmask) ^ km32; + return (float)(h[0]+h[1]); +} + template static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { @@ -363,7 +374,7 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 8*ib; const uint16_t * ql = (const uint16_t *)x[i].ql; uint32_t idx = ql[ib] + 4096; - const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; for (int j = 0; j < 8; ++j) { y[j] = dl * trellis_next(idx); } @@ -398,7 +409,6 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); float scale = 
dptr[0] * 1.00f; - float row_av = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); const int64_t i = ii - (row*n_per_row)/QK_K; @@ -419,8 +429,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int ls = ((shb[ib32] & 0xff) >> 1) - 64; const float dl = scale * ls; for (int j = 0; j < 4; ++j) { - y[j+0] = dl * trellis_next(idx1) + row_av; - y[j+4] = dl * trellis_next(idx2) + row_av; + y[j+0] = dl * trellis_next_int(idx1); + y[j+4] = dl * trellis_next_int(idx2); } } From 65e654a69df1be1bb68feaec40126ad7c728c1dc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 19:12:09 +0300 Subject: [PATCH 07/17] New iq4_kt: AVX2 dot product finally works We get 13.6 t/s vs 8.4 t/s with the f16 trellis and f32 arithmetic. Still somewhat slower than other quants, but no longer pathetic. --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 111 ++--------------------------- 1 file changed, 7 insertions(+), 104 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 5152f33dd..d2774e2a0 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -402,92 +402,6 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, } } -/* -template -void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%QK_K == 0); - const int nb = n/QK_K; - constexpr int kNumGroups = 64; - - Trellis3 trellis; - - constexpr int k_acc = nrc_y; - - __m256 accd[k_acc]; - const block_q8_2_x4 * y[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) { - y[iy] = (const block_q8_2_x4 *)info.src1_row(iy); - } - - __m256i xv[8]; - - const block_iq4_kt * x8[8]; - float dkt[8]; - int32_t ls[8]; - uint32_t idx0[8], idx[8]; - - union { float f; uint32_t u; } bf16_helper; - - for (int ix = 0; ix < nrc_x; ix += 8) { - for (int k = 0; k < 8; ++k) { - const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); - dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); - } - auto vd = _mm256_loadu_ps(dkt); - - for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - for (int ib = 0; ib < QK_K/32; ++ib) { - for (int k = 0; k < 8; ++k) { - ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; - idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; - } - auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i *)ls))); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-126.f)); - int shift1 = 8 - 4*(ib/4); - for (int j = 0; j < 8; ++j) { - for (int k = 0; k < 8; ++k) { - const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); - const uint8_t * qh = ql + kNumGroups; - const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); - idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; - } - xv[j] = trellis.next32(idx); - } - for (int iy = 0; iy < nrc_y; ++iy) { - const auto& yb = y[iy][2*i+ib/4]; - int i4 = ib%4; - auto vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+0); - auto vy = MM256_SET_M128I(vy8, vy8); - auto sumi = _mm256_setzero_si256(); - sumi = _mm256_dpbusd_epi32(sumi, xv[0], _mm256_shuffle_epi32(vy, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, xv[1], _mm256_shuffle_epi32(vy, 0x50)); - sumi = _mm256_dpbusd_epi32(sumi, xv[2], _mm256_shuffle_epi32(vy, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, xv[3], _mm256_shuffle_epi32(vy, 0xff)); - vy8 = _mm_loadu_si128((const __m128i *)yb.qs + 2*i4+1); - vy = MM256_SET_M128I(vy8, vy8); - sumi = 
_mm256_dpbusd_epi32(sumi, xv[4], _mm256_shuffle_epi32(vy, 0x00)); - sumi = _mm256_dpbusd_epi32(sumi, xv[5], _mm256_shuffle_epi32(vy, 0x50)); - sumi = _mm256_dpbusd_epi32(sumi, xv[6], _mm256_shuffle_epi32(vy, 0xaa)); - sumi = _mm256_dpbusd_epi32(sumi, xv[7], _mm256_shuffle_epi32(vy, 0xff)); - bf16_helper.u = yb.d[i4] << 16; - auto d8 = _mm256_mul_ps(scales, _mm256_set1_ps(bf16_helper.f)); - accd[iy] = _mm256_fmadd_ps(d8, _mm256_cvtepi32_ps(sumi), accd[iy]); - bf16_helper.u = yb.d[i4+4] << 16; - accd[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(bf16_helper.f), accd[iy]); - } - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, accd[iy]); - } - } -} -*/ - void iqk_dequantize_iq4_kt(int n, const void * vx, size_t bx, float * y, size_t stride_y, int nrc_x) { GGML_ASSERT(n%QK_K == 0); const int nb = n/QK_K; @@ -573,11 +487,12 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& auto compute_dot = [&dot, &xv] (const int8_t * y) { for (int k = 0; k < 4; ++k) { auto yv = _mm256_loadu_si256((const __m256i *)y + k); - dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); + dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); } }; - auto m126 = _mm256_set1_ps(-126.f); + //auto m126 = _mm256_set1_ps(-126.f); for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); @@ -609,30 +524,18 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; } } - // sum[d4 * (x_i - 126) * d8 * y_i] => d4*d8*sum[x_i*y_i] - 126*d4*(d8*sum[y_i] -> m8) - // d4*d8*sum[x_i*y_i] - 126*d4*m8 for (int i128 = 0; i128 < 2; ++i128) { - for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); - //auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y[0][2*i+i128].d)), 16)); - //auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); - //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); - //m8 = _mm256_mul_ps(m8, _mm256_set1_ps(-126.f)); - //for (int k = 0; k < 4; ++k) { - // xv[k] = trellis.next32(values + 32*i128 + 8*k); - // auto yv = _mm256_loadu_si256((const __m256i *)y[0][2*i+i128].qs + k); - // dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); - //} - //accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], d8), sum_4(), accd[0]); - //accd[0] = _mm256_fmadd_ps(scales[i128], m8, accd[0]); + //for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); for (int iy = 0; iy < nrc_y; ++iy) { const block_q8_2_x4& yb = y[iy][2*i+i128]; auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16)); dy = _mm256_mul_ps(scales[i128], dy); auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)); - auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); + //auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1)); compute_dot(yb.qs); accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]); - accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); + //accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]); } } } From 
9d7bf1c82167b98016af618319e7de523faa775b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Jun 2025 19:33:33 +0300 Subject: [PATCH 08/17] New iq4_kt: fix vanilla AVX2 --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index d2774e2a0..28939fffd 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -117,7 +117,12 @@ struct Trellis3 { } inline __m256 gen8(uint32_t val1, uint32_t val2) const { auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD auto i8 = _mm256_dpbusd_epi32(_mm256_set1_epi32(-126), _mm256_set1_epi32(0x01010101), v8); +#else + auto dot = _mm256_maddubs_epi16(v8, _mm256_set1_epi32(0x01010101)); + auto i8 = _mm256_add_epi32(_mm256_set1_epi32(-126), _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif return _mm256_cvtepi32_ps(i8); } template @@ -126,7 +131,12 @@ struct Trellis3 { __m256i aux[4]; for (int i = 0; i < 4; ++i) { auto i8 = _mm256_and_si256(next8(val[2*i+0], val[2*i+1]), _mm256_set1_epi32(0x3f3f3f3f)); +#ifdef HAVE_FANCY_SIMD aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), i8); +#else + auto dot = _mm256_maddubs_epi16(i8, _mm256_set1_epi32(0x01010101)); + aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1))); +#endif } aux[0] = _mm256_packs_epi32(aux[0], aux[1]); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 aux[2] = _mm256_packs_epi32(aux[2], aux[3]); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -487,8 +497,13 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& auto compute_dot = [&dot, &xv] (const int8_t * y) { for (int k = 0; k < 4; ++k) { auto yv = _mm256_loadu_si256((const __m256i *)y + k); +#ifdef HAVE_FANCY_SIMD //dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[k], yv); dot[k] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); +#else + auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(xv[k], xv[k]), _mm256_sign_epi8(yv, xv[k])); + dot[k] = _mm256_madd_epi16(p, _mm256_set1_epi16(1)); +#endif } }; From 8cc8b1e7954181efabf3cbd05d1cb3f71961c281 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 08:31:56 +0300 Subject: [PATCH 09/17] New iq4_kt: NEON implementation We get very respectable PP-512 = 120 t/s. TG-128 is pathetic at 5.3 t/s, so 20+% slower than the f16 variant. 
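For reference, a minimal scalar sketch of the per-weight step that the new next8()/next32() helpers vectorize. It assumes the same ka/kb constants, the 0x3f3f3f3f mask and the -126 offset used by Trellis3; the helper name trellis_next_i8 is illustrative only, not code from this patch:

    #include <cstdint>

    // Advance the 32-bit LCG state, keep the low 6 bits of each byte, and sum the
    // four bytes. The result lies in [-126, 126], so it fits in an int8_t.
    static inline int trellis_next_i8(uint32_t& state) {
        constexpr uint32_t ka = 89226354;
        constexpr uint32_t kb = 64248484;
        state = ka*state + kb;
        const uint32_t v = state & 0x3f3f3f3f;
        return int((v & 0xff) + ((v >> 8) & 0xff) + ((v >> 16) & 0xff) + ((v >> 24) & 0xff)) - 126;
    }

Each packed index in the kernels seeds this generator for four consecutive weights; next8() and next32() produce 8 and 32 such values at a time.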
--- ggml/src/iqk/iqk_gemm_ktquants.cpp | 207 ++++++++++++++++++++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 205 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 28939fffd..cdee7afc1 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1199,10 +1199,213 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } } +struct Trellis3 { + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3}; + const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3}; + const uint8x16_t shuffle = load_shuffle(); + + inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const { + uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)}; + result.val[0] = vmlaq_u32(mkb, mka, result.val[0]); + result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); + return result; + } + //inline int8x16x2_t next32(const uint32_t * val) const { + // int8x16x4_t aux; + // int8x16x2_t result; + // for (int i = 0; i < 2; ++i) { + // auto i8 = next8(val[4*i+0], val[4*i+1]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // aux.val[0] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); + // aux.val[1] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); + // i8 = next8(val[4*i+2], val[4*i+3]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // aux.val[2] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); + // aux.val[3] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); + // result.val[i] = vqtbl4q_s8(aux, shuffle); + // } + // return result; + //} + // This works: + inline int8x16x2_t next32(const uint32_t * val) const { + uint16x8_t aux[4]; + for (int i = 0; i < 4; ++i) { + auto i8 = next8(val[2*i+0], val[2*i+1]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); + auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); + aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + } + int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; + return result; + } + static uint8x16_t load_shuffle() { + static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + return vld1q_u8(k_shuffle); + } +}; + +void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + const block_iq4_kt * x8[8]; + float dkt[8]; + int32_t ls[8]; + uint32_t idx0[8], idx[8]; + + for (int ix = 0; ix < nrc_x; ix 
+= 8) { + for (int k = 0; k < 8; ++k) { + const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); + dkt[k] = dptr[0]; + x8[k] = (const block_iq4_kt *)(dptr + 2); + } + auto vd = vld1q_f32_x2(dkt); + + for (int i = 0; i < nb; ++i) { + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 8; ++k) { + ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64; + idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096; + } + auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0))); + auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4))); + vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1)); + vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2)); + int shift1 = 8 - 4*(ib/4); + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j); + idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k]; + } + vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx)); + } + } + y += 8; // = QK_K/32; + } + } +} + +template +void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis3 trellis; + + union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; + + constexpr int k_acc = nrc_y; + + float32x4_t accd[k_acc]; + + const block_q8_0_x4 * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); + } + + uint32_t values[64]; + int8x16x2_t xv[4]; + int32x4x4_t dot; + + auto compute_dot = [&dot, &xv] (const int8_t * y) { + for (int k = 0; k < 4; ++k) { + auto yv = vld1q_s8_x2(y + 32*k); + dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); + } + dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]); + dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]); + return vpaddq_s32(dot.val[0], dot.val[2]); + }; + + float32x4x2_t scales; + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + auto d = vdupq_n_f32(dptr[0]); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); + + for (int i = 0; i < nb; ++i) { + auto vshb = vld1q_u32_x2(x[i].qs); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1)); + auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1)); + iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64)); + iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64)); + scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1)); + scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2)); + o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); + for (int ib = 0; ib < 4; ++ib) { + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + 
values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + } + } + for (int i128 = 0; i128 < 2; ++i128) { + for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& yb = y[iy][2*i+i128]; + auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d))); + //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16)); + //dy = vmulq_f32(scales.val[i128], dy); + auto sumi = compute_dot(yb.qs); + accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi)); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } + } +} + } bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { + + if (ne00%QK_K != 0) return false; + + func16 = nullptr; + + if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) { + if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) { + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels); + return true; + } + return false; + } + //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) { // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); // func16 = nullptr; @@ -1213,8 +1416,6 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F16 : type; case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; - case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type; + case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; default: break; } #endif From eed22154d13b152f6e44837cf789fd9776620723 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 08:43:02 +0300 Subject: [PATCH 10/17] New iq4_kt: slightly faster NEON --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index cdee7afc1..4e06b1b8f 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1237,17 +1237,32 @@ struct Trellis3 { // return result; //} // This works: + //inline int8x16x2_t next32(const uint32_t * val) const { + // uint16x8_t aux[4]; + // for (int i = 0; i < 4; ++i) { + // auto i8 = next8(val[2*i+0], val[2*i+1]); + // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + // auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); + // auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); + // aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + // } + // int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; + // return result; + //} inline int8x16x2_t next32(const uint32_t * val) const { - uint16x8_t aux[4]; - for (int i = 0; i < 4; ++i) { - auto i8 = next8(val[2*i+0], val[2*i+1]); + int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; + for (int i = 0; i < 2; ++i) { + auto i8 = next8(val[4*i+0], val[4*i+1]); + i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); + i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); + auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), 
vreinterpretq_s8_u32(i8.val[1])); + i8 = next8(val[4*i+2], val[4*i+3]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); - auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); - aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); + auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); + result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2)); } - int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; return result; } static uint8x16_t load_shuffle() { From b41f4718ad5eb6988784375e3faec7d54a789421 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 09:17:32 +0300 Subject: [PATCH 11/17] New iq4_kt: slightly faster NEON --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 48 ++++++++++++++++++------------ 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 4e06b1b8f..0a17fe838 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1330,7 +1330,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& union { uint32x4x2_t vec; uint32_t val[8]; } o_helper; - constexpr int k_acc = nrc_y; + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; float32x4_t accd[k_acc]; @@ -1339,11 +1339,11 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& y[iy] = (const block_q8_0_x4 *)info.src1_row(iy); } - uint32_t values[64]; - int8x16x2_t xv[4]; + uint32_t values[16]; + int8x16x2_t xv[8]; int32x4x4_t dot; - auto compute_dot = [&dot, &xv] (const int8_t * y) { + auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) { for (int k = 0; k < 4; ++k) { auto yv = vld1q_s8_x2(y + 32*k); dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]); @@ -1379,27 +1379,37 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int j = 0; j < 4; ++j) { const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; } + xv[ib+0] = trellis.next32(values+0); + xv[ib+4] = trellis.next32(values+8); } - for (int i128 = 0; i128 < 2; ++i128) { - for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k); - for (int iy = 0; iy < nrc_y; ++iy) { - const block_q8_0_x4& yb = y[iy][2*i+i128]; - auto dy = vmulq_f32(scales.val[i128], 
vcvt_f32_f16(vld1_f16((const float16_t *)yb.d))); - //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16)); - //dy = vmulq_f32(scales.val[i128], dy); - auto sumi = compute_dot(yb.qs); - accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi)); + for (int iy = 0; iy < nrc_y; ++iy) { + const block_q8_0_x4& ybl = y[iy][2*i+0]; + const block_q8_0_x4& ybh = y[iy][2*i+1]; + auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d))); + auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d))); + auto sumil = compute_dot(ybl.qs, xv+0); + auto sumih = compute_dot(ybh.qs, xv+4); + if constexpr (nrc_y == 1) { + accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil)); + accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih)); + } else { + accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil)); + accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih)); } } } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(accd[iy])); + if constexpr (nrc_y == 1) { + info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1]))); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(accd[iy])); + } } } } From d334cbf552d4fa47747fee764c43de40cc069e03 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 10:05:49 +0300 Subject: [PATCH 12/17] New iq4_kt: faster NEON We are now at 9.4 t/s, up from 6.6 t/s for the f16 trellis. --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 38 +++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 0a17fe838..d4547d4a1 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1353,6 +1353,9 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& return vpaddq_s32(dot.val[0], dot.val[2]); }; + //int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}}; + int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; + float32x4x2_t scales; for (int ix = 0; ix < nrc_x; ++ix) { @@ -1376,14 +1379,33 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096)); for (int ib = 0; ib < 4; ++ib) { - for (int j = 0; j < 4; ++j) { - const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - } + auto vql1 = vmovl_u8(vld1_u8(ql+8*ib)); + auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32)); + auto vqh = vmovl_u8(vld1_u8(qh+8*ib)); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8))); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4))); + auto sh1_u32 = vdupq_n_u32(shb[ib+0]); + auto sh2_u32 = vdupq_n_u32(shb[ib+4]); + auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, 
shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1]))); + auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1]))); + vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1)); + vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2)); + auto oh1 = vdupq_n_u32(o_helper.val[ib+0]); + auto oh2 = vdupq_n_u32(o_helper.val[ib+4]); + vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1)); + vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); + vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); + vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); + //auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts); + //auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts); + //for (int j = 0; j < 4; ++j) { + // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + // values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + // values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + // values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + //} xv[ib+0] = trellis.next32(values+0); xv[ib+4] = trellis.next32(values+8); } From 1e6ff8a78865ab97178ab3ef7e6b4b1977b54fd3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 10:29:40 +0300 Subject: [PATCH 13/17] Minor --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index d4547d4a1..ba35631bf 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1353,7 +1353,6 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& return vpaddq_s32(dot.val[0], dot.val[2]); }; - //int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}}; int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}}; float32x4x2_t scales; @@ -1396,16 +1395,6 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1)); vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2)); vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2)); - //auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts); - //auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts); - //for (int j = 0; j < 4; ++j) { - // const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); - // const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); - // values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; - // values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; - // values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; - // values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; - //} xv[ib+0] = trellis.next32(values+0); xv[ib+4] = trellis.next32(values+8); } @@ -1443,23 +1432,16 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array Date: Sun, 8 Jun 2025 11:47:55 +0300 Subject: [PATCH 14/17] New iq4_kt 
trellis: not working Metal implementation --- ggml/src/ggml-metal.metal | 42 ++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a05a890e2..f850d998a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6596,6 +6596,25 @@ void kernel_mul_mv_iq2_k_f32_impl( } } +struct Trellis3 { + constexpr constant static uint32_t kmask = 0x3f3f3f3f; + constexpr constant static uint32_t ka = 89226354; + constexpr constant static uint32_t kb = 64248484; + constexpr constant static uint32_t ka1 = ka*ka; + constexpr constant static uint32_t kb1 = kb*ka+kb; + constexpr constant static uint32_t ka2 = ka1*ka; + constexpr constant static uint32_t kb2 = kb1*ka+kb; + constexpr constant static uint32_t ka3 = ka2*ka; + constexpr constant static uint32_t kb3 = kb2*ka+kb; + static inline char4 gen4(uint32_t val) { + thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask}; + thread const int8_t * a8 = (thread const int8_t *)aux; + char4 result; + for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; + return result; + } +}; + struct Trellis { constexpr constant static uint32_t kmask1 = 0x8fff8fff; constexpr constant static uint32_t kmask2 = 0x3b603b60; @@ -8586,20 +8605,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread device const uint32_t * shb = x->qs; device const uint8_t * ql = (device const uint8_t *)(shb + 8); device const uint8_t * qh = ql + 64; - float scale = d * (((shb[ib32] & 0xff) >> 1) - 64); + const int ls = (shb[ib32] & 0xff) >> 1; + const float scale = d * (ls - 64); const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); - const int jj = ib32*8 + 4*(il%2); - ql += jj; - qh += jj%32; + ql += 8*ib32; + qh += 8*(ib32%4); uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12; - const int shift = 8 - 4*(jj/32); + const int shift = 8 - 4*(ib32/4); for (int i = 0; i < 4; ++i) { uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset; - auto v = (float4)Trellis::gen4(idx); - reg[i] = v * scale; + auto c4 = Trellis3::gen4(idx); + reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]}; } } @@ -8931,18 +8950,17 @@ struct DequantizerKT4 { using type4x4 = T4x4; DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; - d[0] = dptr[0] * 31.75f * 1.01f; - d[1] = dptr[1]; + d = dptr[0] * 1.01f; x = (device const Block *)(dptr + 2); } inline void convert(thread T4x4& t) const { float4x4 tmp; - dequantize_iq4_kt(x, il, d[0], tmp); + dequantize_iq4_kt(x, il, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void convert(int64_t ind, thread T4x4& t) { float4x4 tmp; - dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp); + dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp); for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j]; } inline void next() { @@ -8951,7 +8969,7 @@ struct DequantizerKT4 { } device const Block * x; short il; - float d[2]; + float d; }; template From b30bfc13776d83a1c4b43d88318afaa7d824704b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 13:36:47 +0300 Subject: [PATCH 15/17] Remove the extra 4 bytes of row meta data that is no longer used --- ggml/src/ggml-cuda/convert.cu | 2 +- ggml/src/ggml-cuda/iqk_mmvq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 4 ++-- 
ggml/src/ggml.c | 7 +------ ggml/src/iqk/iqk_gemm_ktquants.cpp | 4 ++-- ggml/src/iqk/iqk_quantize.cpp | 19 +++++-------------- 6 files changed, 12 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 905595a76..24cb87988 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -409,7 +409,7 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst int64_t row = (QK_K * ii) / n_per_row; const float * dptr = (const float *)((const char *)vx + row * row_size); float scale = dptr[0] * 1.00f; - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); const int64_t i = ii - (row*n_per_row)/QK_K; constexpr int kNumGroups = 64; diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 843e835ec..7dbea3063 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -403,7 +403,7 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1( constexpr uint32_t km = 0x3f3f3f3f; float scale = *(const float *)vbq; - const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + 2*sizeof(float)) + kbx; + const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx; // iqs is 0...28 const int ib32 = iqs/4; // Why iqs/4 ? diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 7d829c7e6..26a7933ca 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2819,7 +2819,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + 2*sizeof(float)) + kbx0; + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0; int ib32 = kqsx/4; int j = kqsx%4; @@ -2855,7 +2855,7 @@ template static __device__ __forceinlin } const float * dptr = (const float *)(x + i*stride); - const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 2) + kbx0; + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0; const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; #ifdef INT8_MMA_AVAILABLE diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b1fa39afa..b23455261 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1622,13 +1622,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif -//#ifdef __ARM_NEON -// .vec_dot_type = GGML_TYPE_F16, -//#else -// .vec_dot_type = GGML_TYPE_F32, -//#endif .nrows = 1, - .row_meta_size = 8, + .row_meta_size = 4, }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index ba35631bf..9c87373d4 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -383,7 +383,7 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, for (int k = 0; k < 8; ++k) { const float * dptr = (const float *)((const char*)vx + (ix+k)*bx); dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); + x8[k] = (const block_iq4_kt *)(dptr + 1); } auto vd = _mm256_loadu_ps(dkt); @@ -512,7 +512,7 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = _mm256_set1_ps(dptr[0]); - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); for 
(int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index db447d84e..0fdffc95d 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -8600,7 +8600,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f float * dptr = (float *)vy; - block_iq4_kt * y = (block_iq4_kt *)(dptr + 2); + block_iq4_kt * y = (block_iq4_kt *)(dptr + 1); auto& quantizer1 = iq4kt_quantizer(); auto& quantizer2 = iq4kt_quantizer(true); @@ -8609,16 +8609,10 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); - float amax_row = 0, row_av = 0; + float amax_row = 0; for (int j = 0; j < n_per_row; ++j) { - row_av += x[j]; amax_row = std::max(amax_row, std::abs(x[j])); } - //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - //row_av /= n_per_row; - row_av = 0; - //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - dptr[1] = row_av; if (!amax_row) { dptr[0] = 0.f; std::memset(y, 0, nblock*sizeof(block_iq4_kt)); @@ -8641,7 +8635,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; float amax = 0; for (int j = 0; j < Q::kBlockSize; ++j) { - xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + xaux[j] = xbl[ib*Q::kBlockSize+j]; float ax = std::abs(xaux[j]); amax = std::max(amax, ax); } @@ -8721,7 +8715,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f for (int ib = 0; ib < Q::kNblock; ++ib) { auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1; const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; - for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j]; int ls = nearest_int(id*scales[ib]); ls = std::min(ls, 63); *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1); @@ -8789,8 +8783,7 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { const int nb = k / Q::kSuperBlockSize; const float * dptr = (const float *)x; const float d = dptr[0] * Q::kScale; - const float row_av = dptr[1]; - x = (const block_iq4_kt *)(dptr + 2); + x = (const block_iq4_kt *)(dptr + 1); auto& deq = iq4kt_dequantizer(); for (int ibl = 0; ibl < nb; ++ibl) { auto shb = x[ibl].qs; @@ -8798,14 +8791,12 @@ void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { auto qh = ql + kNumGroups; for (int ib = 0; ib < Q::kNblock; ++ib) { int offset = shb[ib] & 1 ? 32768 + 4096 : 4096; - //auto& deq = shb[ib] & 1 ? 
deq2 : deq1; int ls = int((shb[ib] & 0xff) >> 1) - 64; float sl = d * ls; for (int ig = 0; ig < Q::kNg; ++ig) { int jj = ib*Q::kNg+ig; uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12); deq.set_values(idx, y, sl, offset); - for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av; y += Q::kGroupSize; } } From 07efec0e4cbe647b6170d6a2e6539466fa82f012 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Jun 2025 13:44:20 +0300 Subject: [PATCH 16/17] Cleanup --- ggml/src/ggml-metal.metal | 2 +- ggml/src/iqk/iqk_gemm_ktquants.cpp | 36 ++---------------------------- 2 files changed, 3 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f850d998a..e1d474040 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -8951,7 +8951,7 @@ struct DequantizerKT4 { DequantizerKT4(device const char * cx, short il = 0) : il(il) { device const float * dptr = (device const float *)cx; d = dptr[0] * 1.01f; - x = (device const Block *)(dptr + 2); + x = (device const Block *)(dptr + 1); } inline void convert(thread T4x4& t) const { float4x4 tmp; diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 9c87373d4..965ecc2da 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1218,38 +1218,6 @@ struct Trellis3 { result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); return result; } - //inline int8x16x2_t next32(const uint32_t * val) const { - // int8x16x4_t aux; - // int8x16x2_t result; - // for (int i = 0; i < 2; ++i) { - // auto i8 = next8(val[4*i+0], val[4*i+1]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // aux.val[0] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); - // aux.val[1] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); - // i8 = next8(val[4*i+2], val[4*i+3]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // aux.val[2] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]))); - // aux.val[3] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]))); - // result.val[i] = vqtbl4q_s8(aux, shuffle); - // } - // return result; - //} - // This works: - //inline int8x16x2_t next32(const uint32_t * val) const { - // uint16x8_t aux[4]; - // for (int i = 0; i < 4; ++i) { - // auto i8 = next8(val[2*i+0], val[2*i+1]); - // i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); - // i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); - // auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])); - // auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])); - // aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2)); - // } - // int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))}; - // return result; - //} inline int8x16x2_t next32(const uint32_t * val) const { int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; for (int i = 0; i < 2; ++i) { @@ -1290,7 +1258,7 @@ void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, for (int k = 0; k < 8; ++k) { const float * dptr = 
(const float *)((const char*)vx + (ix+k)*bx); dkt[k] = dptr[0]; - x8[k] = (const block_iq4_kt *)(dptr + 2); + x8[k] = (const block_iq4_kt *)(dptr + 1); } auto vd = vld1q_f32_x2(dkt); @@ -1360,7 +1328,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = vdupq_n_f32(dptr[0]); - const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1); for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0); From f59fe11764076c69baf1da8d699a4c3cca9b89f2 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 09:59:39 +0300 Subject: [PATCH 17/17] Adding forgotten file --- ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu new file mode 100644 index 000000000..4590f9193 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
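Closing reference: a minimal scalar sketch (not code from any of the patches above) of how one 32-weight block of iq4_kt is decoded with the new integer trellis, assuming the bit layout used by the CPU and Metal kernels in this series. The names trellis_next_i8 and decode_iq4_kt_block are illustrative only.

    #include <cstdint>

    // Same per-weight trellis step as sketched in the PATCH 09 message.
    static inline int trellis_next_i8(uint32_t& state) {
        constexpr uint32_t ka = 89226354;
        constexpr uint32_t kb = 64248484;
        state = ka*state + kb;
        const uint32_t v = state & 0x3f3f3f3f;
        return int((v & 0xff) + ((v >> 8) & 0xff) + ((v >> 16) & 0xff) + ((v >> 24) & 0xff)) - 126;
    }

    // Decode block ib (0..7) of one 256-weight superblock into 32 floats.
    // In the kernels: shb = x[i].qs, ql = (const uint8_t *)(shb + 8), qh = ql + 64,
    // and d is the per-row float scale stored in front of the block data (dptr[0]).
    static void decode_iq4_kt_block(const uint32_t* shb, const uint8_t* ql, const uint8_t* qh,
                                    int ib, float d, float* out) {
        const float    scale  = d * (int((shb[ib] & 0xff) >> 1) - 64);  // 7-bit block scale, bias 64
        const uint32_t offset = 4096 + ((shb[ib] & 1) << 15);           // +32768 when bit 0 is set
        const int      shift  = 8 - 4*(ib/4);                           // low or high nibble of qh
        for (int j = 0; j < 8; ++j) {
            const uint32_t sh = shb[ib] >> (8 + 3*j);                   // 3 extra index bits per group
            uint32_t idx = ql[8*ib + j]
                         + ((qh[8*(ib%4) + j] << shift) & 0xf00)
                         + ((sh & 7) << 12)
                         + offset;
            for (int l = 0; l < 4; ++l) out[4*j + l] = scale * trellis_next_i8(idx);  // 4 weights per index
        }
    }

It follows the same index decode as iqk_dequantize_iq4_kt_q80_r8() and the Metal dequantize_iq4_kt() above, just one value at a time instead of vectorized.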