
Commit 105261d

Few more changes and tweaks
1 parent: 3f8d7a2

2 files changed: 114 additions & 45 deletions

ggml/src/ggml-common.h

Lines changed: 1 addition & 1 deletion
@@ -378,8 +378,8 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
-// 3.4375 bpw
 #define IQ3S_N_SCALE QK_K/64
+// 3.4375 bpw
 typedef struct {
     ggml_half d;
     uint8_t qs[QK_K/4];
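
The relocated 3.4375 bpw comment can be checked arithmetically. A minimal derivation, assuming the tail of block_iq3_s (the qh, signs, and scales fields) from the surrounding ggml-common.h, since this hunk shows only the first two fields; with QK_K = 256:

    // block_iq3_s storage per 256-weight block (fields after qs[] assumed, not shown in this hunk):
    //   ggml_half d               ->   2 bytes
    //   uint8_t   qs[QK_K/4]      ->  64 bytes
    //   uint8_t   qh[QK_K/32]     ->   8 bytes
    //   uint8_t   signs[QK_K/8]   ->  32 bytes
    //   uint8_t   scales[QK_K/64] ->   4 bytes   (IQ3S_N_SCALE = QK_K/64)
    // total: 110 bytes -> 110 * 8 / 256 = 3.4375 bits per weight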

src/llama-quant.cpp

Lines changed: 113 additions & 44 deletions
@@ -84,10 +84,22 @@ struct quantize_state_impl {
     int n_ffn_down = 0;
     int n_ffn_gate = 0;
     int n_ffn_up = 0;
+    int n_ffn_down_exp = 0;
+    int n_ffn_gate_exp = 0;
+    int n_ffn_up_exp = 0;
+    int n_ffn_down_shexp = 0;
+    int n_ffn_gate_shexp = 0;
+    int n_ffn_up_shexp = 0;
     int i_attention_wv = 0;
     int i_ffn_down = 0;
     int i_ffn_gate = 0;
     int i_ffn_up = 0;
+    int i_ffn_down_exp = 0;
+    int i_ffn_gate_exp = 0;
+    int i_ffn_up_exp = 0;
+    int i_ffn_down_shexp = 0;
+    int i_ffn_gate_shexp = 0;
+    int i_ffn_up_shexp = 0;
 
     int n_k_quantized = 0;
     int n_fallback = 0;
@@ -175,6 +187,23 @@ static void llama_tensor_dequantize_impl(
     workers.clear();
 }
 
+// Check if ftype is specifically IQ2_S or IQ2_M
+static inline bool is_iq2s_or_iq2m(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M;
+}
+
+// Check if ftype belongs to the IQ1 group
+static inline bool is_iq1_group(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M;
+}
+
+// Returns the appropriate type for expert _exps tensors based on ftype
+static inline ggml_type get_expert_exps_type(llama_ftype ftype) {
+    if (is_iq1_group(ftype))    return GGML_TYPE_IQ2_XXS;
+    if (is_iq2s_or_iq2m(ftype)) return GGML_TYPE_IQ3_XXS;
+    /* otherwise */             return GGML_TYPE_IQ2_XS;
+}
+
 static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
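The three helpers above exist to collapse repeated ftype comparisons at the call sites changed below; a representative before/after:

    // before:
    new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
    // after:
    new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
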
@@ -242,7 +271,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
         new_type = GGML_TYPE_Q2_K;
     }
-    else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
+    else if (is_iq2s_or_iq2m(ftype)) {
         new_type = GGML_TYPE_IQ3_S;
     }
     else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
@@ -256,7 +285,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
         if (name.find("attn_v.weight") != std::string::npos) {
             if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k.weight") != std::string::npos) {
@@ -266,59 +295,95 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
+            if (qs.i_attention_wv < qs.n_attention_wv/8) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down.weight") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate.weight") != std::string::npos) {
             if (qs.i_ffn_gate < qs.n_ffn_gate/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
-            else if (qs.i_ffn_gate < qs.n_ffn_gate/8 || qs.i_ffn_gate >= 7*qs.n_ffn_gate/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else if (qs.i_ffn_gate < qs.n_ffn_gate/8) {
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_gate;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up.weight") != std::string::npos) {
             if (qs.i_ffn_up < qs.n_ffn_up/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_up < qs.n_ffn_up/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_up;
         }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_down_exp < qs.n_ffn_down_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_down_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_gate_exp < qs.n_ffn_gate_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_gate_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_up_exp < qs.n_ffn_up_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_up_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_down_shexp, qs.n_ffn_down_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_down_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_gate_shexp, qs.n_ffn_gate_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_gate_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_up_shexp, qs.n_ffn_up_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_up_shexp;
+        }
         else if (name.find("ffn_down") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
         else if (name.find("attn_output.weight") != std::string::npos) {
             if (qs.model.hparams.n_expert >= 8) {
-                new_type = GGML_TYPE_Q5_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
             } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
+                if (is_iq1_group(ftype)) new_type = GGML_TYPE_IQ2_XXS;
+                else if (is_iq2s_or_iq2m(ftype)) new_type = GGML_TYPE_IQ3_S;
             }
         }
     } else if (name.find("attn_v.weight") != std::string::npos) {
@@ -465,38 +530,28 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_IQ3_XXS;
         }
         ++qs.i_ffn_up;
-    } else if (name.find("attn_kv_a_mqa") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_a_mqa.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q8_0;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q4_K;
+        if (qs.i_attention_wv < qs.n_attention_wv/16) {
             new_type = GGML_TYPE_Q8_0;
+        } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
+            new_type = GGML_TYPE_Q6_K;
         }
-    } else if (name.find("attn_kv_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            new_type = GGML_TYPE_Q4_K;
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
-                new_type = GGML_TYPE_Q8_0;
-            } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
         ++qs.i_attention_wv;
-    } else if (name.find("attn_q_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
-    } else if (name.find("attn_q_a.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
         }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q5_K;
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q8_0;
     }
 
     // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
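
Taken together, the reorganized branches give models with n_expert >= 8 this ftype-to-type ladder for the low-rank attention projections (derived from the hunk above; for dense models these names now fall through to the later generic handling):

    // attn_kv_a_mqa.weight: always Q8_0
    // attn_q_b.weight: Q3_K_M / Q3_K_L / IQ3_M -> Q4_K;  Q4_K_M / Q5_K_S -> Q5_K;  Q5_K_M -> Q6_K
    // attn_q_a.weight: Q5_K by default;          Q4_K_M / Q5_K_S -> Q6_K;  Q5_K_M -> Q8_0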
@@ -793,11 +848,25 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             ++qs.n_attention_wv;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
+        } else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_exp;
+        } else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_shexp;
+        } else if (name.find("ffn_down_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_down_exp;
+        } else if (name.find("ffn_down_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_down_shexp;
+        } else if (name.find("ffn_up_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_up_exp;
+        } else if (name.find("ffn_up_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_up_shexp;
         }
 
         is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }
 
+    GGML_ASSERT(qs.n_ffn_down_exp != 0);
+
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks for models that have attention layers
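
The per-tensor counters collected here feed the i < n/8 gates in llama_tensor_get_type(). A worked example with hypothetical numbers:

    // Hypothetical MoE model with 61 expert layers: n_ffn_down_exp = 61.
    // Integer division gives 61/8 = 7, so tensors with i_ffn_down_exp = 0..6
    // (the first seven) receive get_expert_exps_type(ftype); the remaining 54
    // keep the type chosen elsewhere. Note that GGML_ASSERT(qs.n_ffn_down_exp != 0)
    // makes at least one ffn_down_exps tensor a hard requirement on this path.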
