@@ -250,6 +250,7 @@ struct cmd_params {
250250    std::vector<bool >                cpu_strict;
251251    std::vector<int >                 poll;
252252    std::vector<int >                 n_gpu_layers;
253+     std::vector<int >                 n_cpu_moe;
253254    std::vector<std::string>         rpc_servers;
254255    std::vector<llama_split_mode>    split_mode;
255256    std::vector<int >                 main_gpu;
@@ -286,6 +287,7 @@ static const cmd_params cmd_params_defaults = {
286287    /*  cpu_strict           */   { false  },
287288    /*  poll                 */   { 50  },
288289    /*  n_gpu_layers         */   { 99  },
290+     /*  n_cpu_moe            */   { 0  },
289291    /*  rpc_servers          */   { " "   },
290292    /*  split_mode           */   { LLAMA_SPLIT_MODE_LAYER },
291293    /*  main_gpu             */   { 0  },
@@ -353,6 +355,8 @@ static void print_usage(int /* argc */, char ** argv) {
353355    printf ("   --poll <0...100>                          (default: %s)\n "  , join (cmd_params_defaults.poll , " ,"  ).c_str ());
354356    printf ("   -ngl, --n-gpu-layers <n>                  (default: %s)\n "  ,
355357           join (cmd_params_defaults.n_gpu_layers , " ,"  ).c_str ());
358+     printf ("   -ncmoe, --n-cpu-moe <n>                   (default: %s)\n "  ,
359+            join (cmd_params_defaults.n_cpu_moe , " ,"  ).c_str ());
356360    if  (llama_supports_rpc ()) {
357361        printf ("   -rpc, --rpc <rpc_servers>                 (default: %s)\n "  ,
358362               join (cmd_params_defaults.rpc_servers , " ,"  ).c_str ());
@@ -564,6 +568,45 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
564568                }
565569                auto  p = parse_int_range (argv[i]);
566570                params.n_gpu_layers .insert (params.n_gpu_layers .end (), p.begin (), p.end ());
571+             } else  if  (arg == " -ncmoe"   || arg == " --n-cpu-moe"  ) {
572+                 if  (++i >= argc) {
573+                     invalid_param = true ;
574+                     break ;
575+                 }
576+ 
577+                 const  auto  values = parse_int_range (argv[i]);
578+                 if  (values.size () != 1 ) {
579+                     invalid_param = true ;
580+                     break ;
581+                 }
582+ 
583+                 const  int  n_layers = values[0 ];
584+                 if  (n_layers < 0 ) {
585+                     invalid_param = true ;
586+                     break ;
587+                 }
588+ 
589+                 if  (n_layers > 0 ) {
590+                     static  std::vector<std::vector<std::string> > buft_batches;
591+                     buft_batches.emplace_back ();
592+                     std::vector<std::string> & batch = buft_batches.back ();
593+                     batch.reserve (static_cast <size_t >(n_layers));
594+ 
595+                     std::vector<llama_model_tensor_buft_override> group_tensor_buft_overrides;
596+                     group_tensor_buft_overrides.reserve (static_cast <size_t >(n_layers) + 1 );
597+ 
598+                     for  (int  i = 0 ; i < n_layers; ++i) {
599+                         batch.push_back (llm_ffn_exps_block_regex (i));
600+                         const  char  * pattern = batch.back ().c_str ();
601+                         group_tensor_buft_overrides.push_back ({
602+                             pattern,
603+                             ggml_backend_cpu_buffer_type ()
604+                         });
605+                     }
606+ 
607+                     group_tensor_buft_overrides.push_back ({ nullptr , nullptr  });
608+                     params.tensor_buft_overrides .push_back (std::move (group_tensor_buft_overrides));
609+                 }
567610            } else  if  (llama_supports_rpc () && (arg == " -rpc"   || arg == " --rpc"  )) {
568611                if  (++i >= argc) {
569612                    invalid_param = true ;
@@ -1514,6 +1557,9 @@ struct markdown_printer : public printer {
15141557        if  (field == " no_op_offload"  ) {
15151558            return  4 ;
15161559        }
1560+         if  (field == " tensor_buft_overrides"  ) {
1561+             return  40 ;
1562+         }
15171563
15181564        int  width = std::max ((int ) field.length (), 10 );
15191565
@@ -1684,6 +1730,12 @@ struct markdown_printer : public printer {
16841730            }
16851731
16861732            int  width = get_field_width (field);
1733+ 
1734+             if  (field == " tensor_buft_overrides"  ) {
1735+                 if  (value.size () > width)
1736+                     value.erase (width);
1737+             }
1738+ 
16871739            if  (field == " t/s"  ) {
16881740                //  HACK: the utf-8 character is 2 bytes
16891741                width += 1 ;
0 commit comments