
Commit 0298c10

PreMerge IK quants and Trellis
* Trellis WIP: Testing Trellis quantization. Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. With the search for the best code SIMDified via AVX2, LLaMA-3.1-8B quantizes in 130 seconds on the Ryzen-7950X CPU, sluggish but still acceptable.
* Testing Trellis quantization: with 4-bit quantized block scales, rmse increases by just 3%, so this beats iq2_xxs in terms of rmse at the same 2.0625 bpw.
* Testing Trellis quantization: playing with scales and generators.
* iq2_kt: quantize / dequantize. I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so it does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
* iq2_kt: CUDA dequantize, so we can run perplexity calculations. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
* WIP: try larger blocks. With blocks of 32 and 16 bits per group of 8, the brute-force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA even after SIMDifying with AVX2). The trick is to group the points into clusters, find the nearest cluster, and only search within that cluster.
* iq2_kt: this is better. Using blocks of 32 and 16 bits per group of 8 weights, it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
* iq2_kt: even better. Re-quantize after determining block scales (at the expense of much longer quantization time).
* iq2_kt: CUDA dot product. Implemented as DMMV. Very slow: just 81 t/s for LLaMA-3.1-8B.
  Then again, Q2_K_S forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
* iq2_kt: very slightly faster CUDA dot product.
* iq2_kt: f16 CUDA dot product. We arrive at 112 t/s.
* iq2_kt: faster f16 CUDA dot product. We arrive at 139 t/s (no FA) and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we are at around ~180 t/s on their GPU, almost matching their performance.
* iq2_kt: faster f16 CUDA dot product. We arrive at 146 t/s (no FA) and 158 t/s (FA), measured for LLaMA-3.1-8B with output.weight left as f16.
* Minor.
* Adding iq3_kt: 3.125 bpw. So far it does not look good on the PPL vs bpw plot.
* Forgotten change.
* WIP
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive with / slightly better than other quants.
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892.
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
* iq3_kt WIP: speed up quantization. Nearly 60% improvement in quantization speed by copying the points belonging to a cluster to contiguous memory during initialization, and then accessing them sequentially while searching for the closest point. LLaMA-3.1-8B now quantizes in ~150 seconds on the Ryzen-5975WX.
* iq3_kt: speed up quantization. Same trick as the last commit, applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops from 375 seconds to 195 seconds!
* iq3_kt: CUDA dot product.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406, PPL(LLaMA-2-7B, 4096) = 6.4179.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* Adding iq4_kt: not competitive at this point.
* WIP
* WIP
* iq4_kt: CUDA dot product.
* iq4_kt: minor tweaks.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297, PPL(LLaMA-2-7B, 4096) = 6.3913. Ah, quantization is faster too, about 20% faster.
* iq3_kt: small improvements and faster quantization.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627, PPL(LLaMA-2-7B, 4096) = 6.3825. Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on the Ryzen-5975WX.
* iq3_kt: small progress.
* WIP
* iq4_kt: go to 4.0 bpw. 15 bits per group of 4, plus 8-bit scales for blocks of 32. This gives a slightly better PPL than iq4_kss.
* iq4_kt: very slightly better, at the expense of much longer quantization time.
* iq4_kt: failed attempt to adjust the CUDA dot product. It was working for 4.125 bpw, but after changing to 4.0 bpw something is wrong and I don't see the bug.
* DRY
* DRY
* iq4_kt: CUDA dot product works.
* DRY
* Report actual bpw.
1 parent 3d77cd8 commit 0298c10

File tree: 18 files changed, +3748 / -16 lines changed

examples/quantize/quantize.cpp

Lines changed: 12 additions & 0 deletions

@@ -33,6 +33,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.96G, +3.5199 ppl @ Llama-3-8B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.96G, +3.1836 ppl @ Llama-3-8B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
+    { "IQ3_KT",  LLAMA_FTYPE_MOSTLY_IQ3_KT,  " 3.125 bpw trellis quantization", },
+    { "IQ4_KT",  LLAMA_FTYPE_MOSTLY_IQ4_KT,  " 4.0 bpw trellis quantization", },
     { "IQ3_S",   LLAMA_FTYPE_MOSTLY_IQ3_S,   " 3.44 bpw quantization", },
     { "IQ3_M",   LLAMA_FTYPE_MOSTLY_IQ3_M,   " 3.66 bpw quantization mix", },
     { "Q3_K",    LLAMA_FTYPE_MOSTLY_Q3_K_M,  "alias for Q3_K_M" },
@@ -50,6 +52,16 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_M",  LLAMA_FTYPE_MOSTLY_Q5_K_M,  " 5.33G, +0.0569 ppl @ Llama-3-8B", },
     { "Q6_K",    LLAMA_FTYPE_MOSTLY_Q6_K,    " 6.14G, +0.0217 ppl @ Llama-3-8B", },
     { "Q8_0",    LLAMA_FTYPE_MOSTLY_Q8_0,    " 7.96G, +0.0026 ppl @ Llama-3-8B", },
+    { "IQ4_KS",  LLAMA_FTYPE_MOSTLY_IQ4_KS,  " 4.25 bpw non-linear quantization", },
+    { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
+    { "IQ2_K",   LLAMA_FTYPE_MOSTLY_IQ2_K,   " 2.375 bpw non-linear quantization", },
+    { "IQ2_KS",  LLAMA_FTYPE_MOSTLY_IQ2_KS,  " 2.1875 bpw non-linear quantization", },
+    { "IQ2_KT",  LLAMA_FTYPE_MOSTLY_IQ2_KT,  " 2.125 bpw trellis quantization", },
+    { "IQ3_K",   LLAMA_FTYPE_MOSTLY_IQ3_K,   " 3.44 bpw non-linear quantization", },
+    { "IQ3_KL",  LLAMA_FTYPE_MOSTLY_IQ3_KL,  " 4 bpw non-linear quantization mix", },
+    { "IQ4_K",   LLAMA_FTYPE_MOSTLY_IQ4_K,   " 4.5 bpw non-linear quantization", },
+    { "IQ5_K",   LLAMA_FTYPE_MOSTLY_IQ5_K,   " 5.5 bpw non-linear quantization", },
+    { "IQ6_K",   LLAMA_FTYPE_MOSTLY_IQ6_K,   " 6.6 bpw non-linear quantization", },
     { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
     { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
     { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },

ggml/include/ggml.h

Lines changed: 31 additions & 0 deletions

@@ -397,6 +397,22 @@ extern "C" {
     GGML_TYPE_TQ2_0   = 35,
     //
     GGML_TYPE_Q6_0    = 133,
+    GGML_TYPE_IQ1_BN  = 134,
+    GGML_TYPE_IQ2_BN  = 135,
+    GGML_TYPE_Q8_K64  = 136,
+    GGML_TYPE_IQ2_K   = 137,
+    GGML_TYPE_IQ3_K   = 138,
+    GGML_TYPE_IQ4_K   = 139,
+    GGML_TYPE_IQ5_K   = 140,
+    GGML_TYPE_IQ6_K   = 141,
+    // depricated: GGML_TYPE_IQ2_TN = 142,
+    // depricated: GGML_TYPE_IQ1_TN = 143,
+    GGML_TYPE_IQ4_KS  = 144,
+    GGML_TYPE_IQ2_KS  = 145,
+    GGML_TYPE_IQ4_KSS = 146,
+    GGML_TYPE_IQ2_KT  = 147,
+    GGML_TYPE_IQ3_KT  = 148,
+    GGML_TYPE_IQ4_KT  = 149,
     GGML_TYPE_COUNT,
 };

@@ -443,6 +459,21 @@ extern "C" {
     GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27,  // except 1d tensors
     //
     GGML_FTYPE_MOSTLY_Q6_0     = 127, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ1_BN   = 128, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ2_BN   = 129, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ2_K    = 130, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ3_K    = 131, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ4_K    = 132, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ5_K    = 133, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ6_K    = 134, // except 1d tensors
+    // depricated: GGML_FTYPE_MOSTLY_IQ2_TN = 135, // except 1d tensors
+    // depricated: GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ4_KS   = 137, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ2_KS   = 138, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ4_KSS  = 139, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ2_KT   = 140, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ3_KT   = 141, // except 1d tensors
+    GGML_FTYPE_MOSTLY_IQ4_KT   = 142, // except 1d tensors
 };

 // available tensor operations:

ggml/src/ggml-common.h

Lines changed: 83 additions & 0 deletions

@@ -429,6 +429,89 @@ typedef struct {
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

+typedef struct {
+    uint8_t scales[QK_K/32];
+    uint8_t qs[QK_K/2];
+} block_iq4_ks;
+static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");
+
+typedef struct {
+    uint32_t qs[QK_K/8];
+} block_iq4_kss;
+static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t extra;
+    uint8_t scales[QK_K/32];
+    uint8_t qs[QK_K/4];
+} block_iq2_k;
+static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding");
+
+typedef struct {
+    uint16_t extra;
+    uint8_t scales[QK_K/64];
+    uint8_t qs[QK_K/4];
+} block_iq2_ks;
+static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding");
+
+typedef struct {
+    uint8_t scales[QK_K/64];
+    uint8_t ql[QK_K/4];
+} block_iq2_kt;
+static_assert(sizeof(block_iq2_kt) == QK_K/4 + QK_K/64, "wrong iq2_kt block size/padding");
+
+typedef struct {
+    uint8_t scales[QK_K/64];
+    uint8_t ql[QK_K/4];
+    uint8_t qh[QK_K/8];
+} block_iq3_kt;
+static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding");
+
+typedef struct {
+    uint32_t qs[QK_K/8];
+} block_iq4_kt;
+static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t extra;
+    uint16_t scales_h;
+    uint8_t scales_l[QK_K/32];
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/8];
+} block_iq3_k;
+static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t extra;
+    uint8_t scales_h[QK_K/64];
+    uint8_t scales_l[QK_K/32];
+    uint8_t qs[QK_K/2];
+} block_iq4_k;
+static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t extra;
+    uint8_t scales_h[QK_K/64];
+    uint8_t scales_l[QK_K/32];
+    uint8_t qs[QK_K/2];
+    uint8_t qh[QK_K/8];
+} block_iq5_k;
+static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t extra;
+    int8_t scales[QK_K/16];
+    uint8_t qs[QK_K/2];
+    uint8_t qh[QK_K/4];
+} block_iq6_k;
+static_assert(sizeof(block_iq6_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/4 + QK_K/16, "wrong iq6_k block size/padding");
+
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL
