 #include <fstream>
 #include <cmath>
 #include <cctype>
+#include <algorithm>
 
 struct quant_option {
     std::string name;
     llama_ftype ftype;
     std::string desc;
 };
 
-static const std::vector<struct quant_option> QUANT_OPTIONS = {
+static const std::vector<quant_option> QUANT_OPTIONS = {
     { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  },
     { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  },
@@ -105,7 +106,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
+    printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -114,6 +116,8 @@ static void usage(const char * executable) {
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
+    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
+    printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) {
     return GGML_TYPE_COUNT;
 }
 
+// Allowed tensors for arbitrary quantization with --tensor-type option
+static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
+    "attn_k",
+    "attn_kv_a_mqa",
+    "attn_kv_b",
+    "attn_o",
+    "attn_output",
+    "attn_q",
+    "attn_q_a",
+    "attn_q_b",
+    "attn_qkv",
+    "attn_v",
+    "channel_mix_key",
+    "channel_mix_receptance",
+    "channel_mix_value",
+    "cls",
+    "cls.output",
+    "cross_attn_k",
+    "cross_attn_o",
+    "cross_attn_q",
+    "cross_attn_v",
+    "ffn_act",
+    "ffn_down",
+    "ffn_down_exps",
+    "ffn_down_shexp",
+    "ffn_gate",
+    "ffn_gate_exps",
+    "ffn_gate_shexp",
+    "ffn_up",
+    "ffn_up_exps",
+    "ffn_up_shexp",
+    "ssm_in",
+    "ssm_out",
+    "time_mix_gate",
+    "time_mix_key",
+    "time_mix_output",
+    "time_mix_receptance",
+    "time_mix_value",
+};
+
+// changes to this struct must be replicated in llama-quant.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
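+// Parse one --tensor-type argument of the form TENSOR=TYPE, e.g. "attn_q=q8_0":
+// the tensor name (lowercased) must end in an ALLOWED_TENSOR_TYPE entry and the
+// type must be a valid ggml_type; on success the pair is appended to tensor_type.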
+static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr) {
+        printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
+        return false;
+    }
+
+    const size_t tn_len = sep - data;
+    if (tn_len == 0) {
+        printf("\n%s: missing tensor name\n\n", __func__);
+        return false;
+    }
+
+    if (const size_t qt_len = strlen(sep); qt_len == 1) {
+        printf("\n%s: missing quantization type\n\n", __func__);
+        return false;
+    }
+
+    std::string tn(data, tn_len);
+    std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
+    sep++;
+    const std::string qt(sep);
+
+    bool found = false;
+    for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
+        std::string tensor;
+        tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
+        // handle special case of cls.output
+        std::string cls_output = "cls.output";
+        if (tn.find(cls_output) != std::string::npos) {
+            tensor = "cls.output";
+        }
+        // check if an allowed tensor exists and it's at the end of the kv string
+        if (tensor == allowed) {
+            found = true;
+            break;
+        }
+    }
+    if (!found) {
+        printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
+        return false;
+    }
+
+    if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
+        printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
+        return false;
+    }
+
+    tensor_quantization tqz;
+    tqz.name = tn;
+    tqz.quant = parse_ggml_type(qt.c_str());
+    tensor_type.emplace_back(std::move(tqz));
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -255,6 +360,7 @@ int main(int argc, char ** argv) {
     std::string imatrix_file;
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<tensor_quantization> tensor_types;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -277,6 +383,10 @@ int main(int argc, char ** argv) {
             } else {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
+            if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -361,6 +471,9 @@ int main(int argc, char ** argv) {
         kv_overrides.back().key[0] = 0;
         params.kv_overrides = &kv_overrides;
     }
+    if (!tensor_types.empty()) {
+        params.tensor_types = &tensor_types;
+    }
 
     llama_backend_init();
 
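With these changes, individual tensor groups can be pinned to a specific quantization type while the rest of the model follows the chosen preset. A minimal sketch of an invocation, assuming the tool is built under its default name `llama-quantize` (the binary name is not part of this diff), using only flags and arguments shown in the usage text above:

    ./llama-quantize --tensor-type attn_q=q8_0 --tensor-type ffn_down=q8_0 \
        model-f32.gguf model-quant.gguf Q4_K_M 8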