Commit 54e13cf

Implement general --tensor-type instead of tensor-specific command option
1 parent 071e9ef commit 54e13cf

File tree

3 files changed: +122 -290 lines changed

examples/quantize/quantize.cpp

Lines changed: 98 additions & 203 deletions
@@ -105,11 +105,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 [[noreturn]]
 static void usage(const char * executable) {
     printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
-    printf(" [--token-embedding-type] [--attention-qkv-type] [--attention-q-type] [--attention-k-type] [--attention-v-type] [--attention-qa-type]\n");
-    printf(" [--attention-qb-type] [--attention-kva-type] [--attention-kvb-type] [--attention-output-type] [--feedforward-up-type] [--feedforward-gate-type]\n");
-    printf(" [--feedforward-down-type] [--feedforward-gate-exp-type] [--feedforward-down-exp-type] [--feedforward-up-exp-type] [--feedforward-gate-shexp-type]\n");
-    printf(" [--feedforward-down-shexp-type] [--feedforward-up-shexp-type] [--classifier-type] [--classifier-output-type] [--override-kv]\n");
-    printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -118,26 +114,8 @@ static void usage(const char * executable) {
     printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
     printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
-    printf(" --attention-qkv-type ggml_type: use this ggml_type for the attn_qkv.weight tensor\n");
-    printf(" --attention-q-type ggml_type: use this ggml_type for the attn_q.weight tensor\n");
-    printf(" --attention-k-type ggml_type: use this ggml_type for the attn_k.weight tensor\n");
-    printf(" --attention-v-type ggml_type: use this ggml_type for the attn_v.weight tensor\n");
-    printf(" --attention-qa-type ggml_type: use this ggml_type for the attn_q_a.weight tensor\n");
-    printf(" --attention-qb-type ggml_type: use this ggml_type for the attn_q_b.weight tensor\n");
-    printf(" --attention-kva-type ggml_type: use this ggml_type for the attn_kv_a_mqa.weight tensor\n");
-    printf(" --attention-kvb-type ggml_type: use this ggml_type for the attn_kv_b.weight tensor\n");
-    printf(" --attention-output-type ggml_type: use this ggml_type for the attn_output.weight tensor\n");
-    printf(" --feedforward-up-type ggml_type: use this ggml_type for the ffn_up.weight tensor\n");
-    printf(" --feedforward-gate-type ggml_type: use this ggml_type for the ffn_gate.weight tensor\n");
-    printf(" --feedforward-down-type ggml_type: use this ggml_type for the ffn_down.weight tensor\n");
-    printf(" --feedforward-up-exp-type ggml_type: use this ggml_type for the ffn_up_exp.weight tensor\n");
-    printf(" --feedforward-gate-exp-type ggml_type: use this ggml_type for the ffn_gate_exp.weight tensor\n");
-    printf(" --feedforward-down-exp-type ggml_type: use this ggml_type for the ffn_down_exp.weight tensor\n");
-    printf(" --feedforward-up-shexp-type ggml_type: use this ggml_type for the ffn_up_shexp.weight tensor\n");
-    printf(" --feedforward-gate-shexp-type ggml_type: use this ggml_type for the ffn_gate_shexp.weight tensor\n");
-    printf(" --feedforward-down-shexp-type ggml_type: use this ggml_type for the ffn_down_shexp.weight tensor\n");
-    printf(" --classifier-type ggml_type: use this ggml_type for the cls.weight tensor\n");
-    printf(" --classifier-output-type ggml_type: use this ggml_type for the cls.output.weight tensor\n");
+    printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
+    printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf(" --keep-split: will generate quantized model in the same shards as input\n");
     printf(" --override-kv KEY=TYPE:VALUE\n");
     printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
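Migration note: each removed flag maps onto the new syntax through the tensor name it used to target in the help text above. For example, the old "--attention-q-type q8_0" (attn_q.weight) becomes "--tensor-type attn_q=q8_0", and "--feedforward-down-shexp-type q6_k" (ffn_down_shexp.weight) becomes "--tensor-type ffn_down_shexp=q6_k".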
@@ -268,6 +246,95 @@ static ggml_type parse_ggml_type(const char * arg) {
     return GGML_TYPE_COUNT;
 }
 
+// Allowed tensors for arbitrary quantization with --tensor-type option
+static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
+    "attn_k",
+    "attn_kv_a_mqa",
+    "attn_kv_b",
+    "attn_out",
+    "attn_q_a",
+    "attn_q_b",
+    "attn_q",
+    "attn_qkv",
+    "attn_v",
+    "channel_mix_key",
+    "channel_mix_receptance",
+    "channel_mix_value",
+    "cls_out",
+    "cls",
+    "dec_attn_k",
+    "dec_attn_out",
+    "dec_attn_q",
+    "dec_attn_v",
+    "dec_cross_attn_k",
+    "dec_cross_attn_out",
+    "dec_cross_attn_q",
+    "dec_cross_attn_v",
+    "ffn_act",
+    "ffn_down_exp",
+    "ffn_down_shexp",
+    "ffn_down",
+    "ffn_gate_exp",
+    "ffn_gate_shexp",
+    "ffn_gate",
+    "ffn_up_exp",
+    "ffn_up_shexp",
+    "ffn_up",
+    "ssm_in",
+    "ssm_out",
+    "time_mix_gate",
+    "time_mix_key",
+    "time_mix_output",
+    "time_mix_receptance",
+    "time_mix_value",
+};
+
+// changes to this struct must be replicated in llama-quant.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
+static bool string_parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr) {
+        printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
+        return false;
+    }
+
+    const size_t tn_len = sep - data;
+    if (tn_len == 0) {
+        printf("\n%s: missing tensor name\n\n", __func__);
+        return false;
+    }
+
+    if (const size_t qt_len = strlen(sep); qt_len == 1) {
+        printf("\n%s: missing quantization type\n\n", __func__);
+        return false;
+    }
+
+    std::string tn(data, tn_len);
+    std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
+    sep++;
+    const std::string qt(sep);
+
+    if (find(ALLOWED_TENSOR_TYPE.begin(), ALLOWED_TENSOR_TYPE.end(), tn) == ALLOWED_TENSOR_TYPE.end()) {
+        printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
+        return false;
+    }
+
+    if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
+        printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
+        return false;
+    }
+
+    tensor_quantization tqz;
+    tqz.name = tn;
+    tqz.quant = parse_ggml_type(qt.c_str());
+    tensor_type.emplace_back(std::move(tqz));
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
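To make the parsing flow above easier to follow, here is a minimal standalone sketch of what string_parse_tensor_type does with a TENSOR=TYPE argument. The trimmed allow-list, the string-based type table, and the names in this sketch are illustrative stand-ins only; the real code validates against ALLOWED_TENSOR_TYPE and parse_ggml_type as shown in the hunk above.

// Standalone sketch (not part of the commit): mirrors the TENSOR=TYPE parsing
// done by string_parse_tensor_type, with a trimmed allow-list and a simple
// string lookup standing in for parse_ggml_type.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

struct tensor_quant_sketch {
    std::string name;
    std::string quant; // the real struct stores a ggml_type enum instead
};

static bool parse_tensor_type_sketch(const char * data, std::vector<tensor_quant_sketch> & out) {
    static const std::vector<std::string> allowed_tensors = { "attn_q", "attn_k", "attn_v", "ffn_down" };
    static const std::vector<std::string> known_types     = { "q4_0", "q5_k", "q6_k", "q8_0" };

    const char * sep = strchr(data, '='); // split on the first '='
    if (sep == nullptr || sep == data || sep[1] == '\0') {
        printf("malformed tensor type '%s'\n", data);
        return false;
    }

    std::string tn(data, sep - data); // tensor name, lowercased like the real parser
    std::transform(tn.begin(), tn.end(), tn.begin(), [](unsigned char c) { return (char) std::tolower(c); });
    const std::string qt(sep + 1);    // quantization type name

    if (std::find(allowed_tensors.begin(), allowed_tensors.end(), tn) == allowed_tensors.end()) {
        printf("invalid tensor name '%s'\n", tn.c_str());
        return false;
    }
    if (std::find(known_types.begin(), known_types.end(), qt) == known_types.end()) {
        printf("invalid quantization type '%s'\n", qt.c_str());
        return false;
    }

    out.push_back({tn, qt});
    return true;
}

int main() {
    std::vector<tensor_quant_sketch> tensor_types;
    // mirrors: --tensor-type attn_q=q8_0 --tensor-type ffn_down=q6_k
    parse_tensor_type_sketch("attn_q=q8_0",   tensor_types);
    parse_tensor_type_sketch("ffn_down=q6_k", tensor_types);
    for (const auto & tq : tensor_types) {
        printf("%s -> %s\n", tq.name.c_str(), tq.quant.c_str());
    }
    return 0;
}

Run with those two example arguments it prints "attn_q -> q8_0" and "ffn_down -> q6_k", matching how repeated --tensor-type options accumulate into the tensor_types vector in main() below.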
@@ -279,6 +346,7 @@ int main(int argc, char ** argv) {
     std::string imatrix_file;
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<tensor_quantization> tensor_types;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -301,184 +369,8 @@ int main(int argc, char ** argv) {
             } else {
                 usage(argv[0]);
             }
-        } else if (strcmp(argv[arg_idx], "--attention-qkv-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_qkv_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_qkv_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-q-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_q_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_q_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-k-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_k_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_k_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-v-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_v_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_v_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-qa-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_qa_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_qa_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-qb-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_qb_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_qb_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-kva-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_kva_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_kva_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-kvb-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_kvb_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_kvb_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--attention-output-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.attn_output_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.attn_output_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-up-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_up_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_up_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-gate-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_gate_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_gate_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-down-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_down_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_down_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-up-exp-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_up_exp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_up_exp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-gate-exp-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_gate_exp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_gate_exp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-down-exp-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_down_exp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_down_exp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-up-shexp_type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_up_shexp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_up_shexp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-gate-shexp-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_gate_shexp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_gate_shexp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--feedforward-down-shexp-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.ffn_down_shexp_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.ffn_down_shexp_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--classifier-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.cls_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.cls_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
-                usage(argv[0]);
-            }
-        } else if (strcmp(argv[arg_idx], "--classifier-output-type") == 0) {
-            if (arg_idx < argc-1) {
-                params.cls_output_tensor_type = parse_ggml_type(argv[++arg_idx]);
-                if (params.cls_output_tensor_type == GGML_TYPE_COUNT) {
-                    usage(argv[0]);
-                }
-            } else {
+        } else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
+            if (arg_idx == argc-1 || !string_parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
@@ -565,6 +457,9 @@ int main(int argc, char ** argv) {
         kv_overrides.back().key[0] = 0;
         params.kv_overrides = &kv_overrides;
     }
+    if (!tensor_types.empty()) {
+        params.tensor_types = &tensor_types;
+    }
 
     llama_backend_init();
 
include/llama.h

Lines changed: 1 addition & 21 deletions
@@ -366,27 +366,7 @@ extern "C" {
         bool keep_split; // quantize to the same number of shards
         void * imatrix; // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
-        ggml_type attn_qkv_tensor_type; // attention query/key/value tensor type
-        ggml_type attn_q_tensor_type; // attention query tensor type
-        ggml_type attn_k_tensor_type; // attention key tensor type
-        ggml_type attn_v_tensor_type; // attention value tensor type
-        ggml_type attn_qa_tensor_type; // attention query a tensor type
-        ggml_type attn_qb_tensor_type; // attention query b tensor type
-        ggml_type attn_kva_tensor_type; // attention key/value a tensor type
-        ggml_type attn_kvb_tensor_type; // attention key/value b tensor type
-        ggml_type attn_output_tensor_type; // attention output tensor type
-        ggml_type ffn_up_tensor_type; // feedforward up tensor type
-        ggml_type ffn_gate_tensor_type; // feedforward gate tensor type
-        ggml_type ffn_down_tensor_type; // feedforward down tensor type
-        ggml_type ffn_up_exp_tensor_type; // feedforward up expert tensor type
-        ggml_type ffn_gate_exp_tensor_type; // feedforward gate expert tensor type
-        ggml_type ffn_down_exp_tensor_type; // feedforward down expert tensor type
-        ggml_type ffn_up_shexp_tensor_type; // feedforward up shared expert tensor type
-        ggml_type ffn_gate_shexp_tensor_type; // feedforward gate shared expert tensor type
-        ggml_type ffn_down_shexp_tensor_type; // feedforward down shared expert tensor type
-        ggml_type cls_tensor_type; // classifier tensor type
-        ggml_type cls_output_tensor_type; // classifier output tensor type
-
+        void * tensor_types; // pointer to vector containing tensor types
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
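With the twenty per-tensor fields replaced by one opaque pointer, a C++ caller fills a std::vector<tensor_quantization> and hands its address to the quantize params, exactly as main() in quantize.cpp now does. Below is a hedged sketch of such a call site; it is not part of the commit, the tensor_quantization definition is copied from quantize.cpp because llama.h itself does not declare it, and the model file names are placeholders.

// Hedged sketch (not from the commit): driving llama_model_quantize with
// per-tensor overrides through the new tensor_types pointer.
#include "llama.h"

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Mirrors the struct in quantize.cpp; per the comment there it must be kept
// in sync with llama-quant.cpp, which consumes the vector behind the void *.
struct tensor_quantization {
    std::string name;
    ggml_type quant = GGML_TYPE_COUNT;
};

int main() {
    std::vector<tensor_quantization> tensor_types;

    tensor_quantization tq_q; // equivalent of --tensor-type attn_q=q8_0
    tq_q.name  = "attn_q";
    tq_q.quant = GGML_TYPE_Q8_0;
    tensor_types.push_back(tq_q);

    tensor_quantization tq_d; // equivalent of --tensor-type ffn_down=q6_k
    tq_d.name  = "ffn_down";
    tq_d.quant = GGML_TYPE_Q6_K;
    tensor_types.push_back(tq_d);

    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M;
    params.tensor_types = &tensor_types; // read back as a std::vector<tensor_quantization> *

    llama_backend_init();
    // placeholder file names, matching the usage text above
    const uint32_t rc = llama_model_quantize("model-f32.gguf", "model-quant.gguf", &params);
    llama_backend_free();

    if (rc != 0) {
        fprintf(stderr, "quantization failed\n");
        return 1;
    }
    return 0;
}

Keeping the field a void * preserves the C layout of llama_model_quantize_params while letting the C++ implementation interpret the vector, the same approach the struct already uses for imatrix and kv_overrides.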
