Commit dbf590d

Trellis Quants Gen1 activated
* WIP
* WIP: Testing Trellis quantization. Using 12 bits per group of 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By SIMDifying the search for the best code with AVX2, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
* Testing Trellis quantization: with 4-bit quantized block scales, rmse increases by just 3%, so this beats iq2_xxs in terms of rmse at the same 2.0625 bpw.
* Testing Trellis quantization: playing with scales and generators.
* iq2_kt: quantize / dequantize. I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so it does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
* iq2_kt: CUDA dequantize, so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
* WIP
* WIP
* WIP - try larger blocks. With blocks of 32 and 16 bits per group of 8, the brute-force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points into clusters, find the nearest cluster, and only search within that cluster (see the search sketch after the log).
* iq2_kt - this is better. Using blocks of 32 and 16 bits per group of 8 weights it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
* iq2_kt - even better. Re-quantize after determining block scales (at the expense of much longer quantization time).
* iq2_kt: CUDA dot product. Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
* iq2_kt: very slightly faster CUDA dot product.
* iq2_kt: f16 CUDA dot product. We arrive at 112 t/s.
* iq2_kt: faster f16 CUDA dot product. We arrive at 139 t/s (no FA) and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we are at around ~180 t/s on their GPU, almost matching their performance.
* iq2_kt: faster f16 CUDA dot product. We arrive at 146 t/s (no FA) and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.
* Minor
* Adding iq3_kt: 3.125 bpw. So far it does not look good on the PPL vs bpw plot.
* Forgotten change
* WIP
* WIP
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive with / slightly better than other quants.
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892.
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
* iq3_kt WIP: speed up quantization. Nearly 60% improvement in quantization speed from copying the points belonging to a cluster to contiguous memory during initialization, and then accessing them sequentially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.
* iq2_kt: speed up quantization. Same trick as the last commit, applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
* iq3_kt: CUDA dot product.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406, PPL(LLaMA-2-7B, 4096) = 6.4179.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* Adding iq4_kt - not competitive at this point.
* WIP
* WIP
* iq4_kt: CUDA dot product.
* iq4_kt: minor tweaks.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297, PPL(LLaMA-2-7B, 4096) = 6.3913. Ah, quantization is faster too - about 20% faster.
* iq3_kt: small improvements and faster quantization.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627, PPL(LLaMA-2-7B, 4096) = 6.3825. Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on the Ryzen-5975WX.
* iq3_kt: small progress.
* WIP
* iq4_kt: go to 4.0 bpw. 15 bits per group of 4, plus 8-bit scales for blocks of 32 (see the bpw accounting after the log). This gives a slightly better PPL than iq4_kss.
* iq4_kt: very slightly better, at the expense of much longer quantization time.
* iq4_kt: failed attempt to adjust the CUDA dot product. It was working for 4.125 bpw, but after changing to 4.0 bpw there is something wrong and I don't see the bug.
* DRY
* DRY
* iq4_kt: CUDA dot product works.
* DRY
* Report actual bpw.
* Minor tweaks.
* Checkpoint.
* Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude plus 1 bpw for the sign. It gives a visible improvement in the PPL vs bpw plot, but that comes at the expense of much longer quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX). I also noticed that the 3INST generator is not actually generating a Gaussian distribution (see the generator sketch after the log). But going to a better generator means readjusting all the hyper-parameters, so I'm leaving it for later.
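The generator sketch: a 3INST-style trellis decoder (after QTIP) expands a small codeword into a group of 8 roughly Gaussian weights, so only the codeword and a block scale need to be stored. This is a minimal sketch; the LCG constants and bit masks are assumptions for illustration, not necessarily the exact values used by iq2_kt/iq3_kt.

```cpp
// Minimal sketch of a 3INST-style trellis generator (after QTIP).
#include <cstdint>
#include <cstdio>
#include <cstring>

// Convert an IEEE fp16 bit pattern to float. Valid for normal fp16 values,
// which is all this generator produces (the masks pin the exponent range).
static float fp16_bits_to_float(uint16_t h) {
    uint32_t sign    = (uint32_t)(h & 0x8000u) << 16;
    uint32_t expmant = (uint32_t)(h & 0x7fffu);
    uint32_t bits    = sign | ((expmant << 13) + ((127u - 15u) << 23));
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// One generator step: an LCG update, then the masked state is reinterpreted
// as two fp16 halves whose sum is the next weight. The sum is only roughly
// Gaussian - the log notes the distribution is not exactly Gaussian.
static float trellis_next(uint32_t &state) {
    constexpr uint32_t ka    = 89226354u;   // LCG multiplier (assumed)
    constexpr uint32_t kb    = 64248484u;   // LCG increment  (assumed)
    constexpr uint32_t kmask = 0x8fff8fffu; // keep sign + low bits per half
    constexpr uint32_t km32  = 0x3b603b60u; // pin fp16 exponents near 1.0
    state = ka*state + kb;
    uint32_t s = (state & kmask) ^ km32;
    return fp16_bits_to_float((uint16_t)(s & 0xffffu))
         + fp16_bits_to_float((uint16_t)(s >> 16));
}

// Dequantize one group of 8 weights: the stored 16-bit codeword seeds the
// generator, and the block scale multiplies every generated value.
static void decode_group_of_8(uint16_t codeword, float scale, float *dst) {
    uint32_t state = codeword;
    for (int j = 0; j < 8; ++j) dst[j] = scale * trellis_next(state);
}

int main() {
    float w[8];
    decode_group_of_8(0x1234, 0.1f, w);
    for (float v : w) std::printf("%g\n", v);
    return 0;
}
```

Because a step is just an integer multiply-add and two fp16 conversions, the CUDA dot-product kernels can regenerate the weights on the fly instead of loading a codebook from memory.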
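The search sketch: quantization must find, for each group of 8 weights, the codeword whose generated point is closest. The trick from the "try larger blocks" and "speed up quantization" entries is a two-stage search over pre-clustered codebook points, with each cluster's points copied to contiguous memory at initialization so the inner loop streams sequentially. A minimal sketch with illustrative names and layout (the real code is AVX2-SIMDified):

```cpp
// Cluster-accelerated nearest-codeword search (sketch).
#include <cstdint>
#include <limits>
#include <vector>

struct ClusteredCodebook {
    int dim;                                    // values per point, e.g. 8
    std::vector<std::vector<float>> centroids;  // one centroid per cluster
    std::vector<std::vector<float>> points;     // cluster -> contiguous points
    std::vector<std::vector<uint16_t>> codes;   // cluster -> codeword per point

    // (1) find the nearest centroid, (2) exhaustively scan only that cluster.
    uint16_t find_nearest(const float *x) const {
        int   best_c = 0;
        float best_d = std::numeric_limits<float>::max();
        for (int c = 0; c < (int)centroids.size(); ++c) {
            float d = 0;
            for (int j = 0; j < dim; ++j) {
                float t = x[j] - centroids[c][j];
                d += t*t;
            }
            if (d < best_d) { best_d = d; best_c = c; }
        }
        // Points of the chosen cluster are contiguous, so this loop streams
        // through memory sequentially - the ~60% speedup noted in the log.
        const std::vector<float> &pts = points[best_c];
        int npoints = (int)pts.size() / dim;
        int best_p  = 0;
        best_d = std::numeric_limits<float>::max();
        for (int p = 0; p < npoints; ++p) {
            const float *y = pts.data() + p*dim;
            float d = 0;
            for (int j = 0; j < dim; ++j) {
                float t = x[j] - y[j];
                d += t*t;
            }
            if (d < best_d) { best_d = d; best_p = p; }
        }
        return codes[best_c][best_p];
    }
};
```

The codebook is built once (generate all reachable points, k-means them into clusters, copy each cluster's members into its own contiguous array); after that every group of 8 costs one centroid scan plus one small in-cluster scan instead of a full brute-force search.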
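The bpw accounting behind the quoted sizes, as a worked example (the 0.125 bpw of block-scale overhead for iq3_kt is inferred from the stated 3.125 bpw total, not given explicitly in the log):

```latex
% iq4_kt: 15 bits per group of 4, plus an 8-bit scale per block of 32
\frac{15}{4} + \frac{8}{32} = 3.75 + 0.25 = 4.0\ \text{bpw}
% iq3_kt: 16 magnitude bits per group of 8, 1 bpw of signs,
% plus (inferred) 0.125 bpw of block scales
\frac{16}{8} + 1 + 0.125 = 3.125\ \text{bpw}
```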
1 parent a3122c0 commit dbf590d

10 files changed: +1554 -58 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 2 deletions

@@ -9712,8 +9712,8 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ2_K:
         case GGML_TYPE_IQ2_KS:
         case GGML_TYPE_IQ2_KT:
-        // case GGML_TYPE_IQ3_KT:
-        // case GGML_TYPE_IQ4_KT:
+        case GGML_TYPE_IQ3_KT:
+        case GGML_TYPE_IQ4_KT:
         case GGML_TYPE_IQ3_K:
         case GGML_TYPE_IQ4_K:
         case GGML_TYPE_IQ5_K:

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 13 deletions

@@ -524,19 +524,19 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
     static constexpr int qi = QI4_XS;
 };
 
-template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
-    static constexpr int qk = QK_K;
-    static constexpr int qr = QR4_XS;
-    static constexpr int qi = QI4_XS;
-};
-
-template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
-    static constexpr int qk = QK_K;
-    static constexpr int qr = QR4_XS;
-    static constexpr int qi = QI4_XS;
-};
+// template<>
+// struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
+//     static constexpr int qk = QK_K;
+//     static constexpr int qr = QR4_XS;
+//     static constexpr int qi = QI4_XS;
+// };
+
+// template<>
+// struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
+//     static constexpr int qk = QK_K;
+//     static constexpr int qr = QR4_XS;
+//     static constexpr int qi = QI4_XS;
+// };
 
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {

ggml/src/ggml-cuda/convert.cu

Lines changed: 35 additions & 31 deletions

@@ -398,7 +398,11 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
 }
 
 //template<typename dst_t>
-//static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, const int64_t nrows, const int64_t n_per_row) {
+
+//static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, const int64_t nrows, const int64_t n_per_row) { next
+
+//static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { orig
+
 //
 //    int64_t ii = blockIdx.x;
 //    int64_t row = (QK_K * ii) / n_per_row;
@@ -1004,26 +1008,26 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-// template<typename dst_t>
-// static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
-//     const int64_t k = nrows * n_per_row;
-//     const int nb = k / QK_K;
-//     dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
-// }
+template<typename dst_t>
+static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
+}
 
-// template<typename dst_t>
-// static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
-//     const int64_t k = nrows * n_per_row;
-//     const int nb = k / QK_K;
-//     dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
-// }
+template<typename dst_t>
+static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
+}
 
-// template<typename dst_t>
-// static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
-//     const int64_t k = nrows * n_per_row;
-//     const int nb = k / QK_K;
-//     dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
-// }
+template<typename dst_t>
+static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
+}
 
 template<typename dst_t>
 static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
@@ -1266,12 +1270,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
-        // case GGML_TYPE_IQ2_KT:
-        //     return dequantize_row_iq2_kt_cuda;
-        // case GGML_TYPE_IQ3_KT:
-        //     return dequantize_row_iq3_kt_cuda;
-        // case GGML_TYPE_IQ4_KT:
-        //     return dequantize_row_iq4_kt_cuda;
+        case GGML_TYPE_IQ2_KT:
+            return dequantize_row_iq2_kt_cuda;
+        case GGML_TYPE_IQ3_KT:
+            return dequantize_row_iq3_kt_cuda;
+        case GGML_TYPE_IQ4_KT:
+            return dequantize_row_iq4_kt_cuda;
         case GGML_TYPE_IQ2_XS:
             return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ2_S:
@@ -1343,12 +1347,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
-        // case GGML_TYPE_IQ2_KT:
-        //     return dequantize_row_iq2_kt_cuda;
-        // case GGML_TYPE_IQ3_KT:
-        //     return dequantize_row_iq3_kt_cuda;
-        // case GGML_TYPE_IQ4_KT:
-        //     return dequantize_row_iq4_kt_cuda;
+        case GGML_TYPE_IQ2_KT:
+            return dequantize_row_iq2_kt_cuda;
+        case GGML_TYPE_IQ3_KT:
+            return dequantize_row_iq3_kt_cuda;
+        case GGML_TYPE_IQ4_KT:
+            return dequantize_row_iq4_kt_cuda;
         case GGML_TYPE_IQ2_XS:
             return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ2_S:

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 12 additions & 0 deletions

@@ -542,6 +542,11 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
                                  + __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
 }
 
+__device__ __forceinline__ float vec_dot_iq2_kt_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    return 0.f;
+}
+
 #define VDR_IQ3_K_Q8_1_MMVQ 4
 #define VDR_IQ3_K_Q8_1_MMQ  4
 
@@ -770,6 +775,13 @@
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 void mul_mat_vec_iq5_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ void mul_mat_vec_iq2_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
+void mul_mat_vec_iq2_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
 void mul_mat_vec_iq1_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 42 additions & 0 deletions

@@ -446,6 +446,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ2_KS:
             mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        //case GGML_TYPE_IQ2_KT:
+        //    mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+        //    break;
         case GGML_TYPE_IQ5_K:
             mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
@@ -466,3 +469,42 @@
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
+    switch (src0_type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q6_0:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ1_BN:
+        case GGML_TYPE_IQ2_BN:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ2_K:
+        case GGML_TYPE_IQ3_K:
+        case GGML_TYPE_IQ4_K:
+        case GGML_TYPE_IQ4_KS:
+        case GGML_TYPE_IQ4_KSS:
+        case GGML_TYPE_IQ2_KS:
+        //case GGML_TYPE_IQ2_KT:
+        case GGML_TYPE_IQ5_K:
+        case GGML_TYPE_IQ6_K:
+        case GGML_TYPE_IQ3_S:
+            return true;
+        default:
+            return false;
+    }
+}

ggml/src/ggml-cuda/mmvq.cuh

Lines changed: 2 additions & 0 deletions

@@ -7,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec_q(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);

ggml/src/ggml-quants.c

Lines changed: 3 additions & 3 deletions

@@ -5462,9 +5462,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         // case GGML_TYPE_IQ2_BN: break;
         // case GGML_TYPE_IQ2_K: break;
         case GGML_TYPE_IQ2_KS: break;
-        // case GGML_TYPE_IQ2_KT: break;
-        // case GGML_TYPE_IQ3_KT: break;
-        // case GGML_TYPE_IQ4_KT: break;
+        case GGML_TYPE_IQ2_KT: break;
+        case GGML_TYPE_IQ3_KT: break;
+        case GGML_TYPE_IQ4_KT: break;
         // case GGML_TYPE_IQ3_K: break;
         // case GGML_TYPE_IQ4_K: break;
        // case GGML_TYPE_IQ5_K: break;

ggml/src/ggml.c

Lines changed: 9 additions & 9 deletions

@@ -971,9 +971,9 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .blck_size = QK_K,
         .type_size = sizeof(block_iq2_kt),
         .is_quantized = true,
-        // .to_float = (ggml_to_float_t) dequantize_row_iq2_kt,
+        .to_float = (ggml_to_float_t) dequantize_row_iq2_kt,
         // .from_float = quantize_row_iq2_kt,
-        // .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref,
         // .vec_dot = vec_dot_iq2_kt_q8_k,
         // .vec_dot_type = GGML_TYPE_Q8_K,
         // .nrows = 1,
@@ -984,9 +984,9 @@
         .blck_size = QK_K,
         .type_size = sizeof(block_iq3_kt),
         .is_quantized = true,
-        // .to_float = (ggml_to_float_t) dequantize_row_iq3_kt,
+        .to_float = (ggml_to_float_t) dequantize_row_iq3_kt,
         // .from_float = quantize_row_iq3_kt,
-        // .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref,
         // .vec_dot = vec_dot_iq3_kt_q8_k,
         // .vec_dot_type = GGML_TYPE_Q8_K,
         // .nrows = 1,
@@ -997,9 +997,9 @@
         .blck_size = QK_K,
         .type_size = sizeof(block_iq4_kt),
         .is_quantized = true,
-        // .to_float = (ggml_to_float_t) dequantize_row_iq4_kt,
+        .to_float = (ggml_to_float_t) dequantize_row_iq4_kt,
         // .from_float = quantize_row_iq4_kt,
-        // .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
         // .vec_dot = vec_dot_iq4_kt_q8_k,
         // .vec_dot_type = GGML_TYPE_Q8_K,
         // .nrows = 1,
@@ -6873,9 +6873,9 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_K:   result = quantize_iq2_k  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_KS:  result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        // case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        // case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        // case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ2_KT:  result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ3_KT:  result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ4_KT:  result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ3_K:   result = quantize_iq3_k  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_K:   result = quantize_iq4_k  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ5_K:   result = quantize_iq5_k  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
