
Commit c07f5d7: Few more changes and tweaks
1 parent 4f90fac

File tree (2 files changed, +114 −45 lines):
  ggml/src/ggml-common.h
  src/llama-quant.cpp


ggml/src/ggml-common.h

Lines changed: 1 addition & 1 deletion
@@ -368,8 +368,8 @@ typedef struct {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
-// 3.4375 bpw
 #define IQ3S_N_SCALE QK_K/64
+// 3.4375 bpw
 typedef struct {
     ggml_half d;
     uint8_t qs[QK_K/4];
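
As a sanity check on the relocated "3.4375 bpw" comment (this snippet is not part of the commit): the figure follows from the size of the IQ3_S block that IQ3S_N_SCALE feeds into. The d and qs fields are visible in the hunk above; the qh, signs, and scales fields below are assumed from the surrounding ggml-common.h definitions, with QK_K = 256.

// Standalone check of the "3.4375 bpw" figure for IQ3_S.
// Field layout is a sketch of block_iq3_s from ggml-common.h; d and qs
// appear in the hunk above, qh/signs/scales are assumed from context.
#include <cstdint>
#include <cstdio>

#define QK_K 256
#define IQ3S_N_SCALE (QK_K/64)
typedef uint16_t ggml_half; // fp16 storage, 2 bytes

typedef struct {
    ggml_half d;                    //   2 bytes: per-block scale
    uint8_t   qs[QK_K/4];           //  64 bytes: quant index data
    uint8_t   qh[QK_K/32];          //   8 bytes: high index bits
    uint8_t   signs[QK_K/8];        //  32 bytes: 1 sign bit per weight
    uint8_t   scales[IQ3S_N_SCALE]; //   4 bytes: sub-block scales
} block_iq3_s;                      // 110 bytes per 256 weights

int main() {
    // 110 bytes * 8 bits / 256 weights = 3.4375 bits per weight
    printf("%.4f bpw\n", 8.0 * sizeof(block_iq3_s) / QK_K);
}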

src/llama-quant.cpp

Lines changed: 113 additions & 44 deletions
@@ -28,10 +28,22 @@ struct quantize_state_impl {
     int n_ffn_down = 0;
     int n_ffn_gate = 0;
     int n_ffn_up = 0;
+    int n_ffn_down_exp = 0;
+    int n_ffn_gate_exp = 0;
+    int n_ffn_up_exp = 0;
+    int n_ffn_down_shexp = 0;
+    int n_ffn_gate_shexp = 0;
+    int n_ffn_up_shexp = 0;
     int i_attention_wv = 0;
     int i_ffn_down = 0;
     int i_ffn_gate = 0;
     int i_ffn_up = 0;
+    int i_ffn_down_exp = 0;
+    int i_ffn_gate_exp = 0;
+    int i_ffn_up_exp = 0;
+    int i_ffn_down_shexp = 0;
+    int i_ffn_gate_shexp = 0;
+    int i_ffn_up_shexp = 0;
 
     int n_k_quantized = 0;
     int n_fallback = 0;
@@ -119,6 +131,23 @@ static void llama_tensor_dequantize_impl(
     workers.clear();
 }
 
+// Check if ftype is specifically IQ2_S or IQ2_M
+static inline bool is_iq2s_or_iq2m(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M;
+}
+
+// Check if ftype belongs to the IQ1 group
+static inline bool is_iq1_group(llama_ftype ftype) {
+    return ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M;
+}
+
+// Returns the appropriate type for expert _exps tensors based on ftype
+static inline ggml_type get_expert_exps_type(llama_ftype ftype) {
+    if (is_iq1_group(ftype))    return GGML_TYPE_IQ2_XXS;
+    if (is_iq2s_or_iq2m(ftype)) return GGML_TYPE_IQ3_XXS;
+    /* otherwise */ return GGML_TYPE_IQ2_XS;
+}
+
 static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
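The three new helpers centralize ftype checks that the rest of the diff then substitutes inline. A self-contained sketch (not part of the commit) exercising the resulting mapping; llama.h and ggml.h are the repo's public headers, and the helper bodies are restated from the hunk above since the originals are internal to llama-quant.cpp.

// Sketch: verify the ftype -> expert-tensor-type mapping from the diff.
// Build against the llama.cpp headers.
#include "llama.h"
#include "ggml.h"
#include <cassert>

static bool is_iq2s_or_iq2m(llama_ftype t) {
    return t == LLAMA_FTYPE_MOSTLY_IQ2_S || t == LLAMA_FTYPE_MOSTLY_IQ2_M;
}
static bool is_iq1_group(llama_ftype t) {
    return t == LLAMA_FTYPE_MOSTLY_IQ1_S || t == LLAMA_FTYPE_MOSTLY_IQ1_M;
}
static ggml_type get_expert_exps_type(llama_ftype t) {
    if (is_iq1_group(t))    return GGML_TYPE_IQ2_XXS; // IQ1_S / IQ1_M
    if (is_iq2s_or_iq2m(t)) return GGML_TYPE_IQ3_XXS; // IQ2_S / IQ2_M
    return GGML_TYPE_IQ2_XS;                          // everything else
}

int main() {
    assert(get_expert_exps_type(LLAMA_FTYPE_MOSTLY_IQ1_M)   == GGML_TYPE_IQ2_XXS);
    assert(get_expert_exps_type(LLAMA_FTYPE_MOSTLY_IQ2_S)   == GGML_TYPE_IQ3_XXS);
    assert(get_expert_exps_type(LLAMA_FTYPE_MOSTLY_IQ2_XXS) == GGML_TYPE_IQ2_XS);
    return 0;
}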
@@ -175,7 +204,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
             new_type = GGML_TYPE_Q2_K;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
+        else if (is_iq2s_or_iq2m(ftype)) {
             new_type = GGML_TYPE_IQ3_S;
         }
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
@@ -189,7 +218,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
         if (name.find("attn_v.weight") != std::string::npos) {
             if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k.weight") != std::string::npos) {
@@ -199,59 +228,95 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
+            if (qs.i_attention_wv < qs.n_attention_wv/8) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_attention_wv;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
             new_type = GGML_TYPE_Q4_K;
         }
         else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down.weight") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
             else if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate.weight") != std::string::npos) {
             if (qs.i_ffn_gate < qs.n_ffn_gate/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
-            else if (qs.i_ffn_gate < qs.n_ffn_gate/8 || qs.i_ffn_gate >= 7*qs.n_ffn_gate/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+            else if (qs.i_ffn_gate < qs.n_ffn_gate/8) {
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_gate;
         }
-        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up") != std::string::npos) {
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up.weight") != std::string::npos) {
             if (qs.i_ffn_up < qs.n_ffn_up/16) {
                 new_type = GGML_TYPE_Q4_K;
             }
            else if (qs.i_ffn_up < qs.n_ffn_up/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_up;
         }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_down_exp < qs.n_ffn_down_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_down_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_gate_exp < qs.n_ffn_gate_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_gate_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_exps.weight") != std::string::npos) {
+            if (qs.i_ffn_up_exp < qs.n_ffn_up_exp/8) {
+                new_type = get_expert_exps_type(ftype);
+            }
+            ++qs.i_ffn_up_exp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_down_shexp, qs.n_ffn_down_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_down_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_gate_shexp, qs.n_ffn_gate_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_gate_shexp;
+        }
+        else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_up_shexp.weight") != std::string::npos) {
+            if (use_more_bits(qs.i_ffn_up_shexp, qs.n_ffn_up_shexp)) {
+                new_type = GGML_TYPE_Q4_K;
+            }
+            ++qs.i_ffn_up_shexp;
+        }
         else if (name.find("ffn_down") != std::string::npos) {
             if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
             }
             ++qs.i_ffn_down;
         }
         else if (name.find("attn_output.weight") != std::string::npos) {
             if (qs.model.hparams.n_expert >= 8) {
-                new_type = GGML_TYPE_Q5_K;
+                new_type = is_iq2s_or_iq2m(ftype) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
             } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
+                if (is_iq1_group(ftype)) new_type = GGML_TYPE_IQ2_XXS;
+                else if (is_iq2s_or_iq2m(ftype)) new_type = GGML_TYPE_IQ3_S;
             }
         }
     } else if (name.find("attn_v.weight") != std::string::npos) {
@@ -398,38 +463,28 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
            new_type = GGML_TYPE_IQ3_XXS;
        }
        ++qs.i_ffn_up;
-    } else if (name.find("attn_kv_a_mqa") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_a_mqa.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q8_0;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q4_K;
+        if (qs.i_attention_wv < qs.n_attention_wv/16) {
            new_type = GGML_TYPE_Q8_0;
+        } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
+            new_type = GGML_TYPE_Q6_K;
        }
-    } else if (name.find("attn_kv_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            new_type = GGML_TYPE_Q4_K;
-            if (qs.i_attention_wv < qs.n_attention_wv/16) {
-                new_type = GGML_TYPE_Q8_0;
-            } else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
        ++qs.i_attention_wv;
-    } else if (name.find("attn_q_b.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-        }
-    } else if (name.find("attn_q_a.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert >= 8) {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-            else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
+            new_type = GGML_TYPE_Q4_K;
        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
+    } else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) {
+        new_type = GGML_TYPE_Q5_K;
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q8_0;
    }
 
    // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
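
One behavioral subtlety in this hunk: the n_expert >= 8 guard moves from the branch body into the else-if condition itself. In an else-if chain that is not purely cosmetic, because a branch that matches the name but does nothing still swallows the tensor and prevents later rules from firing; with the guard hoisted, non-MoE models now fall through to the remaining rules. A hypothetical sketch (all names invented, not from the commit) of the difference:

// Why hoisting a guard into the else-if condition changes behavior.
#include <cstdio>
#include <string>

static const char * old_style(const std::string & name, int n_expert) {
    if (name.find("attn_kv_a_mqa") != std::string::npos) {
        // matches for every model; for n_expert < 8 it does nothing,
        // but it still consumes the tensor and blocks later branches
        return n_expert >= 8 ? "Q8_0" : "default (chain stopped here)";
    } else if (name.find("attn") != std::string::npos) {
        return "later attn rule";
    }
    return "default";
}

static const char * new_style(const std::string & name, int n_expert) {
    if (n_expert >= 8 && name.find("attn_kv_a_mqa") != std::string::npos) {
        return "Q8_0";
    } else if (name.find("attn") != std::string::npos) {
        // non-MoE models now reach the later rules
        return "later attn rule";
    }
    return "default";
}

int main() {
    printf("old: %s\n", old_style("blk.0.attn_kv_a_mqa.weight", 1)); // chain stopped
    printf("new: %s\n", new_style("blk.0.attn_kv_a_mqa.weight", 1)); // later attn rule
}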
@@ -695,9 +750,23 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
            ++qs.n_attention_wv;
        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
            qs.has_output = true;
+        } else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_exp;
+        } else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_gate_shexp;
+        } else if (name.find("ffn_down_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_down_exp;
+        } else if (name.find("ffn_down_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_down_shexp;
+        } else if (name.find("ffn_up_exps.weight") != std::string::npos) {
+            ++qs.n_ffn_up_exp;
+        } else if (name.find("ffn_up_shexp.weight") != std::string::npos) {
+            ++qs.n_ffn_up_shexp;
        }
    }
 
+    GGML_ASSERT(qs.n_ffn_down_exp != 0);
+
    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
    // sanity checks for models that have attention layers