@@ -84,10 +84,22 @@ struct quantize_state_impl {
     int n_ffn_down = 0;
     int n_ffn_gate = 0;
     int n_ffn_up = 0;
+    int n_ffn_down_exp = 0;
+    int n_ffn_gate_exp = 0;
+    int n_ffn_up_exp = 0;
+    int n_ffn_down_shexp = 0;
+    int n_ffn_gate_shexp = 0;
+    int n_ffn_up_shexp = 0;
     int i_attention_wv = 0;
     int i_ffn_down = 0;
     int i_ffn_gate = 0;
     int i_ffn_up = 0;
+    int i_ffn_down_exp = 0;
+    int i_ffn_gate_exp = 0;
+    int i_ffn_up_exp = 0;
+    int i_ffn_down_shexp = 0;
+    int i_ffn_gate_shexp = 0;
+    int i_ffn_up_shexp = 0;

     int n_k_quantized = 0;
     int n_fallback = 0;
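
Note: the new counters follow the existing n_*/i_* pattern: each n_* total is filled in by a counting pre-pass over the tensors, and the matching i_* cursor advances as tensors are quantized, so a branch can tell how far into the layer stack the current tensor sits. A minimal standalone sketch of the pattern (not part of the commit; the 32-tensor total is hypothetical):

```cpp
#include <cstdio>

struct counters {
    int n_ffn_down_exp = 0; // total, filled by a pre-pass over all tensors
    int i_ffn_down_exp = 0; // cursor, advanced as tensors are quantized
};

int main() {
    counters qs;
    qs.n_ffn_down_exp = 32; // hypothetical MoE model with 32 such tensors
    for (int l = 0; l < qs.n_ffn_down_exp; ++l) {
        // same gating shape as the ffn_*_exps branches later in the diff:
        // only the first 1/8 of the tensors get the upgraded type
        if (qs.i_ffn_down_exp < qs.n_ffn_down_exp/8) {
            printf("tensor %d: upgrade\n", l);
        }
        ++qs.i_ffn_down_exp;
    }
    return 0;
}
```
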
@@ -175,6 +187,23 @@ static void llama_tensor_dequantize_impl(
     workers.clear();
 }

+// Check if ftype is specifically IQ2_S or IQ2_M
+static inline bool is_iq2s_or_iq2m(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M;
+}
+
+// Check if ftype belongs to the IQ1 group
+static inline bool is_iq1_group(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M;
+}
+
+// Returns the appropriate type for expert _exps tensors based on ftype
+static inline ggml_type get_expert_exps_type(llama_ftype ftype) {
+    if (is_iq1_group(ftype)) return GGML_TYPE_IQ2_XXS;
+    if (is_iq2s_or_iq2m(ftype)) return GGML_TYPE_IQ3_XXS;
+    /* otherwise */ return GGML_TYPE_IQ2_XS;
+}
+
 static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
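
The three helpers fold repeated ftype comparisons into named predicates; get_expert_exps_type encodes the mapping IQ1_S/IQ1_M to IQ2_XXS, IQ2_S/IQ2_M to IQ3_XXS, and everything else that reaches it to IQ2_XS. A standalone check of that mapping (hypothetical test harness; the helpers are static in the source file, so they are copied into the sketch):

```cpp
#include <cstdio>
#include "ggml.h"   // ggml_type, ggml_type_name()
#include "llama.h"  // llama_ftype

// helpers copied from the patch above so the sketch is self-contained
static inline bool is_iq2s_or_iq2m(llama_ftype ftype) {
    return ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M;
}
static inline bool is_iq1_group(llama_ftype ftype) {
    return ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M;
}
static inline ggml_type get_expert_exps_type(llama_ftype ftype) {
    if (is_iq1_group(ftype)) return GGML_TYPE_IQ2_XXS;
    if (is_iq2s_or_iq2m(ftype)) return GGML_TYPE_IQ3_XXS;
    return GGML_TYPE_IQ2_XS;
}

int main() {
    const llama_ftype cases[] = {
        LLAMA_FTYPE_MOSTLY_IQ1_S,   // IQ1 group     -> IQ2_XXS
        LLAMA_FTYPE_MOSTLY_IQ2_M,   // IQ2_S/IQ2_M   -> IQ3_XXS
        LLAMA_FTYPE_MOSTLY_IQ2_XXS, // anything else -> IQ2_XS
    };
    for (llama_ftype f : cases) {
        printf("ftype %d -> %s\n", (int) f, ggml_type_name(get_expert_exps_type(f)));
    }
    return 0;
}
```
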
@@ -242,7 +271,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
             new_type = GGML_TYPE_Q2_K;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
+        else if (is_iq2s_or_iq2m(ftype)) {
             new_type = GGML_TYPE_IQ3_S;
         }
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
@@ -256,7 +285,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
         if (name.find("attn_v.weight") != std::string::npos) {
             if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k.weight") != std::string::npos) {
@@ -266,59 +295,95 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
+            if (qs.i_attention_wv < qs.n_attention_wv/8) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down.weight") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate.weight") != std::string::npos) {
             if (qs.i_ffn_gate < qs.n_ffn_gate/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
-            else if (qs.i_ffn_gate < qs.n_ffn_gate/8 || qs.i_ffn_gate >= 7*qs.n_ffn_gate/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else if (qs.i_ffn_gate < qs.n_ffn_gate/8) {
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_gate;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up.weight") != std::string::npos) {
             if (qs.i_ffn_up < qs.n_ffn_up/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_up < qs.n_ffn_up/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_up;
         }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_down_exp < qs.n_ffn_down_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_down_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_gate_exp < qs.n_ffn_gate_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_gate_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_up_exp < qs.n_ffn_up_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_up_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_down_shexp, qs.n_ffn_down_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_down_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_gate_shexp, qs.n_ffn_gate_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_gate_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_up_shexp, qs.n_ffn_up_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_up_shexp;
+        }
         else if (name.find("ffn_down") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
         else if (name.find("attn_output.weight") != std::string::npos) {
             if (qs.model.hparams.n_expert >= 8) {
-                new_type = GGML_TYPE_Q5_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
             } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
+                if (is_iq1_group(ftype)) new_type = GGML_TYPE_IQ2_XXS;
+                else if (is_iq2s_or_iq2m(ftype)) new_type = GGML_TYPE_IQ3_S;
             }
         }
     } else if (name.find("attn_v.weight") != std::string::npos) {
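
Note the pattern change from "ffn_down" to "ffn_down.weight" (and likewise for gate/up) in the branches above: name.find does plain substring search, so the old pattern would also match the new ffn_*_exps tensor names and shadow their dedicated branches. A small standalone demo of the difference (hypothetical tensor names):

```cpp
#include <cassert>
#include <string>

int main() {
    const std::string expert = "blk.7.ffn_down_exps.weight";
    const std::string dense  = "blk.7.ffn_down.weight";

    // old pattern: substring match is too broad, also hits the expert tensor
    assert(expert.find("ffn_down") != std::string::npos);

    // new pattern: the expert tensor no longer matches ...
    assert(expert.find("ffn_down.weight") == std::string::npos);
    // ... while the dense FFN tensor still does
    assert(dense.find("ffn_down.weight") != std::string::npos);

    return 0;
}
```
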
@@ -465,38 +530,28 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_IQ3_XXS;
         }
         ++qs.i_ffn_up;
-    } else if (name.find("attn_kv_a_mqa") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_a_mqa.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q8_0;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q4_K;
+        if (qs.i_attention_wv < qs.n_attention_wv/16) {
             new_type = GGML_TYPE_Q8_0;
+        } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
+            new_type = GGML_TYPE_Q6_K;
         }
-    } else if (name.find("attn_kv_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            new_type = GGML_TYPE_Q4_K;
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
-                new_type = GGML_TYPE_Q8_0;
-            } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
         ++qs.i_attention_wv;
-    } else if (name.find("attn_q_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
-    } else if (name.find("attn_q_a.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
         }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q5_K;
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q8_0;
     }

     // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
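
Several branches above gate on use_more_bits(qs.i_attention_wv, qs.n_attention_wv). For reference, upstream llama.cpp defines that predicate roughly as below (reproduced from memory; verify against src/llama-quant.cpp before relying on it). It selects the first eighth of the layers, the last eighth, and every third layer in between:

```cpp
#include <cstdio>

// Approximation of upstream llama.cpp's use_more_bits() predicate
// (reproduced from memory; check src/llama-quant.cpp for the authoritative version):
static bool use_more_bits(int i_layer, int n_layers) {
    // first eighth, last eighth, and every third layer in between
    return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
}

int main() {
    const int n_layers = 32; // hypothetical layer count
    for (int i = 0; i < n_layers; ++i) {
        if (use_more_bits(i, n_layers)) printf("layer %2d gets more bits\n", i);
    }
    return 0;
}
```
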
@@ -793,11 +848,25 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             ++qs.n_attention_wv;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
+        } else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_exp;
+        } else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_shexp;
+        } else if (name.find("ffn_down_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_down_exp;
+        } else if (name.find("ffn_down_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_down_shexp;
+        } else if (name.find("ffn_up_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_up_exp;
+        } else if (name.find("ffn_up_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_up_shexp;
         }

         is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }

+    GGML_ASSERT(qs.n_ffn_down_exp != 0);
+
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

     // sanity checks for models that have attention layers
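
The counting chain added above is the pre-pass that fills the n_ffn_*_exp and n_ffn_*_shexp totals consumed by the selection logic earlier in the diff; the new GGML_ASSERT then requires that at least one ffn_down_exps tensor was seen, i.e. that the model being quantized is an MoE model. A condensed standalone sketch of the pre-pass shape (hypothetical tensor list):

```cpp
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // hypothetical tensor names as they might appear in a MoE GGUF
    const std::vector<std::string> names = {
        "blk.0.ffn_gate_exps.weight",
        "blk.0.ffn_down_exps.weight",
        "blk.0.ffn_up_exps.weight",
        "blk.0.ffn_gate_shexp.weight",
    };
    int n_ffn_down_exp = 0, n_ffn_gate_exp = 0, n_ffn_up_exp = 0, n_ffn_gate_shexp = 0;
    for (const auto & name : names) {
        if      (name.find("ffn_gate_exps.weight")  != std::string::npos) ++n_ffn_gate_exp;
        else if (name.find("ffn_gate_shexp.weight") != std::string::npos) ++n_ffn_gate_shexp;
        else if (name.find("ffn_down_exps.weight")  != std::string::npos) ++n_ffn_down_exp;
        else if (name.find("ffn_up_exps.weight")    != std::string::npos) ++n_ffn_up_exp;
    }
    printf("down_exp=%d gate_exp=%d up_exp=%d gate_shexp=%d\n",
           n_ffn_down_exp, n_ffn_gate_exp, n_ffn_up_exp, n_ffn_gate_shexp);
    return 0;
}
```
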