@@ -28,10 +28,23 @@ struct quantize_state_impl {
     int n_ffn_down = 0;
     int n_ffn_gate = 0;
     int n_ffn_up = 0;
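+    // totals (n_*) and running indices (i_*) for MoE expert tensors: routed (*_exps) and shared (*_shexp)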
+    int n_ffn_down_exp = 0;
+    int n_ffn_gate_exp = 0;
+    int n_ffn_up_exp = 0;
+    int n_ffn_down_shexp = 0;
+    int n_ffn_gate_shexp = 0;
+    int n_ffn_up_shexp = 0;
     int i_attention_wv = 0;
     int i_ffn_down = 0;
     int i_ffn_gate = 0;
     int i_ffn_up = 0;
+    int i_ffn_down_exp = 0;
+    int i_ffn_gate_exp = 0;
+    int i_ffn_up_exp = 0;
+    int i_ffn_down_shexp = 0;
+    int i_ffn_gate_shexp = 0;
+    int i_ffn_up_shexp = 0;
 
     int n_k_quantized = 0;
     int n_fallback = 0;
@@ -119,6 +131,23 @@ static void llama_tensor_dequantize_impl(
     workers.clear();
 }
 
+// Check if ftype is specifically IQ2_S or IQ2_M
+static inline bool is_iq2s_or_iq2m(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M;
+}
+
+// Check if ftype belongs to the IQ1 group
+static inline bool is_iq1_group(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M;
+}
+
+// Returns the appropriate type for expert _exps tensors based on ftype
+static inline ggml_type get_expert_exps_type(llama_ftype ftype) {
+    if (is_iq1_group(ftype))    return GGML_TYPE_IQ2_XXS;
+    if (is_iq2s_or_iq2m(ftype)) return GGML_TYPE_IQ3_XXS;
+    /* otherwise */             return GGML_TYPE_IQ2_XS;
+}
+
 static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
@@ -175,7 +204,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
             new_type = GGML_TYPE_Q2_K;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
+        else if (is_iq2s_or_iq2m(ftype)) {
             new_type = GGML_TYPE_IQ3_S;
         }
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
@@ -189,7 +218,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
         ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
         if (name.find("attn_v.weight") != std::string::npos) {
             if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k.weight") != std::string::npos) {
@@ -199,59 +228,97 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
+            if (qs.i_attention_wv < qs.n_attention_wv/8) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down.weight") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate.weight") != std::string::npos) {
             if (qs.i_ffn_gate < qs.n_ffn_gate/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
-            else if (qs.i_ffn_gate < qs.n_ffn_gate/8 || qs.i_ffn_gate >= 7*qs.n_ffn_gate/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else if (qs.i_ffn_gate < qs.n_ffn_gate/8) {
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_gate;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up.weight") != std::string::npos) {
             if (qs.i_ffn_up < qs.n_ffn_up/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_up < qs.n_ffn_up/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_up;
         }
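+        // MoE routed-expert tensors: the first 1/8 of each *_exps series gets one tier above the base low-bit type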
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_down_exp < qs.n_ffn_down_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_down_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_gate_exp < qs.n_ffn_gate_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_gate_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_up_exp < qs.n_ffn_up_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_up_exp;
+        }
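+        // shared-expert tensors: use_more_bits() selects a subset of layers to upgrade to Q4_K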
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_down_shexp, qs.n_ffn_down_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_down_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_gate_shexp, qs.n_ffn_gate_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_gate_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_up_shexp, qs.n_ffn_up_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_up_shexp;
+        }
         else if (name.find("ffn_down") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
         else if (name.find("attn_output.weight") != std::string::npos) {
             if (qs.model.hparams.n_expert >= 8) {
-                new_type = GGML_TYPE_Q5_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
             } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
+                if (is_iq1_group(ftype)) new_type = GGML_TYPE_IQ2_XXS;
+                else if (is_iq2s_or_iq2m(ftype)) new_type = GGML_TYPE_IQ3_S;
             }
         }
     } else if (name.find("attn_v.weight") != std::string::npos) {
@@ -398,38 +463,29 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_IQ3_XXS;
         }
         ++qs.i_ffn_up;
-    } else if (name.find("attn_kv_a_mqa") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_a_mqa.weight") != std::string::npos) {
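+        // MLA latent KV projection: a small tensor, so Q8_0 costs little and is likely precision-sensitive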
+        new_type = GGML_TYPE_Q8_0;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q4_K;
+        if (qs.i_attention_wv < qs.n_attention_wv/16) {
             new_type = GGML_TYPE_Q8_0;
+        } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
+            new_type = GGML_TYPE_Q6_K;
         }
-    } else if (name.find("attn_kv_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            new_type = GGML_TYPE_Q4_K;
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
-                new_type = GGML_TYPE_Q8_0;
-            } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
         ++qs.i_attention_wv;
-    } else if (name.find("attn_q_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
-    } else if (name.find("attn_q_a.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
         }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q5_K;
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q8_0;
     }
 
     // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
@@ -695,9 +750,24 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             ++qs.n_attention_wv;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
+        } else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_exp;
+        } else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_shexp;
+        } else if (name.find("ffn_down_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_down_exp;
+        } else if (name.find("ffn_down_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_down_shexp;
+        } else if (name.find("ffn_up_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_up_exp;
+        } else if (name.find("ffn_up_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_up_shexp;
         }
     }
 
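+    // NOTE: assumes an MoE model; a dense model has no ffn_*_exps tensors, so this assert would fire there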
+    GGML_ASSERT(qs.n_ffn_down_exp != 0);
+
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks for models that have attention layers