
Commit a0ad75f

Make ATTN_NORM and FFN_NORM quantizable with a specified type

1 parent 73dd5b3 commit a0ad75f

3 files changed: 31 additions, 1 deletion


examples/quantize/quantize.cpp

Lines changed: 15 additions & 1 deletion
@@ -102,7 +102,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--attn-norm-type] [--ffn-norm-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -117,6 +117,8 @@ static void usage(const char * executable) {
     printf("  --attn-v-type ggml_type: use this ggml_type for the attn_v.weight tensor.\n");
     printf("  --attn-qkv-type ggml_type: use this ggml_type for the attn_qkv.weight tensor.\n");
     printf("  --attn-output-type ggml_type: use this ggml_type for the attn_output.weight tensor.\n");
+    printf("  --attn-norm-type ggml_type: use this ggml_type instead of F32 for the tiny attn_norm.weight tensor.\n");
+    printf("  --ffn-norm-type ggml_type: use this ggml_type instead of F32 for the tiny ffn_norm tensor.\n");
     printf("  --ffn-gate-type ggml_type: use this ggml_type for the ffn_gate tensor.\n");
     printf("  --ffn-down-type ggml_type: use this ggml_type for the ffn_down tensor.\n");
     printf("  --ffn-up-type ggml_type: use this ggml_type for the ffn_up tensor.\n\n");
@@ -314,6 +316,18 @@ int main(int argc, char ** argv) {
         } else {
             usage(argv[0]);
         }
+    } else if (strcmp(argv[arg_idx], "--attn-norm-type") == 0) {
+        if (arg_idx < argc-1) {
+            params.attn_norm_type = parse_ggml_type(argv[++arg_idx]);
+        } else {
+            usage(argv[0]);
+        }
+    } else if (strcmp(argv[arg_idx], "--ffn-norm-type") == 0) {
+        if (arg_idx < argc-1) {
+            params.ffn_norm_type = parse_ggml_type(argv[++arg_idx]);
+        } else {
+            usage(argv[0]);
+        }
     } else if (strcmp(argv[arg_idx], "--ffn-gate-type") == 0) {
         if (arg_idx < argc-1) {
             params.ffn_gate_type = parse_ggml_type(argv[++arg_idx]);
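
The new flags follow the same pattern as the existing per-tensor overrides. A minimal invocation sketch (illustrative, not from this commit: the llama-quantize binary name and the CQS/q8_0 choices are assumptions; per the src/llama.cpp hunks below, the norm overrides only take effect when the selected ftype is LLAMA_FTYPE_CQS):

./llama-quantize --attn-norm-type q8_0 --ffn-norm-type q8_0 \
    model-f32.gguf model-quant.gguf CQS 8

parse_ggml_type matches the lowercase ggml type names (e.g. q8_0), and the trailing 8 is the thread count from the existing [nthreads] argument.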

include/llama.h

Lines changed: 2 additions & 0 deletions
@@ -358,6 +358,8 @@ extern "C" {
         enum ggml_type attn_v_type;      // attention value tensor type
         enum ggml_type attn_qkv_type;    // attention query-key-value tensor type
         enum ggml_type attn_output_type; // attention output tensor type
+        enum ggml_type attn_norm_type;   // attention normalization tensor type
+        enum ggml_type ffn_norm_type;    // feedforward network normalization tensor type
         enum ggml_type ffn_gate_type;    // feedforward network gate type
         enum ggml_type ffn_down_type;    // feedforward network down type
         enum ggml_type ffn_up_type;      // feedforward network up type
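
For callers using the C API directly, the two new fields sit alongside the existing per-tensor overrides in llama_model_quantize_params. A minimal sketch under the same assumptions as above (LLAMA_FTYPE_CQS and Q8_0 are illustrative choices, not prescribed by the commit):

#include "llama.h"

int main(void) {
    // Defaults set every override to GGML_TYPE_COUNT, i.e. "not set".
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype          = LLAMA_FTYPE_CQS;  // fork's custom quantization scheme (assumption)
    params.attn_norm_type = GGML_TYPE_Q8_0;   // quantize attn_norm.weight instead of keeping F32
    params.ffn_norm_type  = GGML_TYPE_Q8_0;   // likewise for the ffn_norm tensors
    // Returns 0 on success.
    return (int) llama_model_quantize("model-f32.gguf", "model-quant.gguf", &params);
}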

src/llama.cpp

Lines changed: 14 additions & 0 deletions
@@ -16221,6 +16221,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XL && (use_more_bits(i_layer, n_layer))) new_type = GGML_TYPE_IQ4_XS;
         ++qs.i_ffn_up;
     }
+    else if (name.find("attn_norm.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_CQS && qs.params->attn_norm_type < GGML_TYPE_COUNT) new_type = qs.params->attn_norm_type;
+    }
+    else if (name.find("ffn_norm") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_CQS && qs.params->ffn_norm_type < GGML_TYPE_COUNT) new_type = qs.params->ffn_norm_type;
+    }
 
     // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
     //}
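
The substring matching via std::string::find is what makes these branches apply per layer: GGUF tensor names carry a "blk.N." prefix (e.g. "blk.17.attn_norm.weight"), so an exact comparison against "attn_norm.weight" would never fire for layer tensors. A standalone illustration (not from the commit):

#include <cstdio>
#include <string>

int main() {
    const std::string name = "blk.17.attn_norm.weight"; // typical per-layer tensor name
    // Same test as llama_tensor_get_type uses: substring, not exact match.
    if (name.find("attn_norm.weight") != std::string::npos) {
        std::printf("norm override applies to %s\n", name.c_str());
    }
}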
@@ -16649,6 +16655,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             if (params->attn_output_type < GGML_TYPE_COUNT && strcmp(tensor->name, "attn_output.weight") == 0) {
                 new_type = params->attn_output_type;
             }
+            if (params->attn_norm_type < GGML_TYPE_COUNT && strcmp(tensor->name, "attn_norm.weight") == 0) {
+                new_type = params->attn_norm_type;
+            }
+            if (params->ffn_norm_type < GGML_TYPE_COUNT && strcmp(tensor->name, "ffn_norm") == 0) {
+                new_type = params->ffn_norm_type;
+            }
             if (params->ffn_gate_type < GGML_TYPE_COUNT && strcmp(tensor->name, "ffn_gate") == 0) {
                 new_type = params->ffn_gate_type;
             }
@@ -17065,6 +17077,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.attn_v_type      =*/ GGML_TYPE_COUNT,
         /*.attn_qkv_type    =*/ GGML_TYPE_COUNT,
         /*.attn_output_type =*/ GGML_TYPE_COUNT,
+        /*.attn_norm_type   =*/ GGML_TYPE_COUNT,
+        /*.ffn_norm_type    =*/ GGML_TYPE_COUNT,
         /*.ffn_gate_type    =*/ GGML_TYPE_COUNT,
         /*.ffn_down_type    =*/ GGML_TYPE_COUNT,
         /*.ffn_up_type      =*/ GGML_TYPE_COUNT,
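
As with the other overrides, GGML_TYPE_COUNT doubles as the "not set" sentinel: it is the one-past-the-end value of the ggml_type enum, so a check of the form override < GGML_TYPE_COUNT holds exactly when the user supplied a real type. A sketch of the idiom (hypothetical helper mirroring the checks above, not part of the commit):

#include "ggml.h"

// GGML_TYPE_COUNT means "no override requested"; anything below it is a valid type.
static enum ggml_type apply_override(enum ggml_type fallback, enum ggml_type override_type) {
    return override_type < GGML_TYPE_COUNT ? override_type : fallback;
}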
