Skip to content

Commit e92db00

Browse files
committed
Refactor quantisation checks into their own function
1 parent 814f6b6 commit e92db00

File tree

1 file changed

+57
-83
lines changed

1 file changed

+57
-83
lines changed

src/llama-quant.cpp

Lines changed: 57 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,60 @@ struct tensor_quantization {
2121
ggml_type quant = GGML_TYPE_COUNT;
2222
};
2323

24+
static bool is_quantizable(const std::string & name, const llm_arch arch, const llama_model_quantize_params * params) {
25+
if (params->only_copy) { return false; }
26+
27+
const auto tn = LLM_TN(arch);
28+
29+
// This used to be a regex, but <regex> has an extreme cost to compile times.
30+
bool q = name.size() >= 6 && name.rfind("weight") == name.size() - 6; // ends with 'weight'?
31+
32+
// Do not quantize norm tensors
33+
q &= name.find("_norm.weight") == std::string::npos;
34+
35+
// Do not quantize expert gating tensors
36+
// NOTE: can't use LLM_TN here because the layer number is not known
37+
q &= name.find("ffn_gate_inp.weight") == std::string::npos;
38+
39+
// These are very small (e.g. 4x4)
40+
q &= name.find("altup") == std::string::npos;
41+
q &= name.find("laurel") == std::string::npos;
42+
43+
// These are not too big so keep them as it is
44+
q &= name.find("per_layer_model_proj") == std::string::npos;
45+
46+
// Do not quantize positional embeddings and token types (BERT)
47+
q &= name != tn(LLM_TENSOR_POS_EMBD, "weight");
48+
q &= name != tn(LLM_TENSOR_TOKEN_TYPES, "weight");
49+
50+
// Do not quantize Jamba, Mamba, LFM2's small yet 2D weights
51+
// NOTE: can't use LLM_TN here because the layer number is not known
52+
q &= name.find("ssm_conv1d.weight") == std::string::npos;
53+
q &= name.find("shortconv.conv.weight") == std::string::npos;
54+
55+
// Do not quantize ARWKV, RWKV's small yet 2D weights
56+
q &= name.find("time_mix_first.weight") == std::string::npos;
57+
q &= name.find("time_mix_w0.weight") == std::string::npos;
58+
q &= name.find("time_mix_w1.weight") == std::string::npos;
59+
q &= name.find("time_mix_w2.weight") == std::string::npos;
60+
q &= name.find("time_mix_v0.weight") == std::string::npos;
61+
q &= name.find("time_mix_v1.weight") == std::string::npos;
62+
q &= name.find("time_mix_v2.weight") == std::string::npos;
63+
q &= name.find("time_mix_a0.weight") == std::string::npos;
64+
q &= name.find("time_mix_a1.weight") == std::string::npos;
65+
q &= name.find("time_mix_a2.weight") == std::string::npos;
66+
q &= name.find("time_mix_g1.weight") == std::string::npos;
67+
q &= name.find("time_mix_g2.weight") == std::string::npos;
68+
q &= name.find("time_mix_decay_w1.weight") == std::string::npos;
69+
q &= name.find("time_mix_decay_w2.weight") == std::string::npos;
70+
q &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
71+
72+
// Do not quantize relative position bias (T5)
73+
q &= name.find("attn_rel_b.weight") == std::string::npos;
74+
75+
return q;
76+
}
77+
2478
static bool is_iq(const enum ggml_type t) {
2579
switch (t) {
2680
case GGML_TYPE_IQ1_S:
@@ -684,40 +738,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
684738
return is_compatible(t, fb) ? fb : GGML_TYPE_F16;
685739
};
686740

687-
auto name_tn = LLM_TN(model.arch);
688741
auto can_quantize = [&](const ggml_tensor * t) -> bool {
689-
// This list should be kept in sync with llama_tensor_quantize_impl() to avoid drift
690-
const std::string name = ggml_get_name(t);
691-
bool q = name.rfind("weight") == name.size() - 6;
692-
q &= ggml_n_dims(t) >= 2;
693-
q &= name.find("_norm.weight") == std::string::npos;
694-
q &= name.find("ffn_gate_inp.weight") == std::string::npos;
695-
q &= name.find("altup") == std::string::npos;
696-
q &= name.find("laurel") == std::string::npos;
697-
q &= name.find("per_layer_model_proj") == std::string::npos;
698-
q &= name != name_tn(LLM_TENSOR_POS_EMBD, "weight");
699-
q &= name != name_tn(LLM_TENSOR_TOKEN_TYPES, "weight");
700-
q &= name.find("ssm_conv1d.weight") == std::string::npos;
701-
q &= name.find("shortconv.conv.weight") == std::string::npos;
702-
q &= name.find("time_mix_first.weight") == std::string::npos;
703-
q &= name.find("time_mix_w0.weight") == std::string::npos;
704-
q &= name.find("time_mix_w1.weight") == std::string::npos;
705-
q &= name.find("time_mix_w2.weight") == std::string::npos;
706-
q &= name.find("time_mix_v0.weight") == std::string::npos;
707-
q &= name.find("time_mix_v1.weight") == std::string::npos;
708-
q &= name.find("time_mix_v2.weight") == std::string::npos;
709-
q &= name.find("time_mix_a0.weight") == std::string::npos;
710-
q &= name.find("time_mix_a1.weight") == std::string::npos;
711-
q &= name.find("time_mix_a2.weight") == std::string::npos;
712-
q &= name.find("time_mix_g1.weight") == std::string::npos;
713-
q &= name.find("time_mix_g2.weight") == std::string::npos;
714-
q &= name.find("time_mix_decay_w1.weight") == std::string::npos;
715-
q &= name.find("time_mix_decay_w2.weight") == std::string::npos;
716-
q &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
717-
q &= name.find("attn_rel_b.weight") == std::string::npos;
718-
q &= !params->only_copy;
719-
720-
return q;
742+
if (ggml_n_dims(t) < 2) { return false; }
743+
return is_quantizable(ggml_get_name(t), model.arch, params);
721744
};
722745

723746
// Estimate error for a given type using a sampled subset of rows
@@ -1747,57 +1770,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
17471770
LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
17481771
++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type));
17491772

1750-
// This used to be a regex, but <regex> has an extreme cost to compile times.
1751-
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
1752-
1753-
// quantize only 2D and 3D tensors (experts)
1754-
quantize &= (ggml_n_dims(tensor) >= 2);
1755-
1756-
// do not quantize norm tensors
1757-
quantize &= name.find("_norm.weight") == std::string::npos;
1758-
1773+
bool quantize = ggml_n_dims(tensor) >= 2 && is_quantizable(name, model.arch, params);
17591774
quantize &= params->quantize_output_tensor || name != "output.weight";
1760-
quantize &= !params->only_copy;
1761-
1762-
// do not quantize expert gating tensors
1763-
// NOTE: can't use LLM_TN here because the layer number is not known
1764-
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
1765-
1766-
// these are very small (e.g. 4x4)
1767-
quantize &= name.find("altup") == std::string::npos;
1768-
quantize &= name.find("laurel") == std::string::npos;
1769-
1770-
// these are not too big so keep them as it is
1771-
quantize &= name.find("per_layer_model_proj") == std::string::npos;
1772-
1773-
// do not quantize positional embeddings and token types (BERT)
1774-
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
1775-
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
1776-
1777-
// do not quantize Mamba's small yet 2D weights
1778-
// NOTE: can't use LLM_TN here because the layer number is not known
1779-
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
1780-
quantize &= name.find("shortconv.conv.weight") == std::string::npos;
1781-
1782-
// do not quantize RWKV's small yet 2D weights
1783-
quantize &= name.find("time_mix_first.weight") == std::string::npos;
1784-
quantize &= name.find("time_mix_w0.weight") == std::string::npos;
1785-
quantize &= name.find("time_mix_w1.weight") == std::string::npos;
1786-
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
1787-
quantize &= name.find("time_mix_v0.weight") == std::string::npos;
1788-
quantize &= name.find("time_mix_v1.weight") == std::string::npos;
1789-
quantize &= name.find("time_mix_v2.weight") == std::string::npos;
1790-
quantize &= name.find("time_mix_a0.weight") == std::string::npos;
1791-
quantize &= name.find("time_mix_a1.weight") == std::string::npos;
1792-
quantize &= name.find("time_mix_a2.weight") == std::string::npos;
1793-
quantize &= name.find("time_mix_g1.weight") == std::string::npos;
1794-
quantize &= name.find("time_mix_g2.weight") == std::string::npos;
1795-
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
1796-
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
1797-
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
1798-
1799-
// do not quantize relative position bias (T5)
1800-
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
18011775

18021776
ggml_type new_type;
18031777
void * new_data;

0 commit comments

Comments
 (0)