@@ -283,7 +283,7 @@ static const cmd_params cmd_params_defaults = {
283283 /* type_k */ {GGML_TYPE_F16},
284284 /* type_v */ {GGML_TYPE_F16},
285285 /* n_threads */ {{cpu_get_num_math (), cpu_get_num_math ()}},
286- /* n_gpu_layers */ {99 },
286+ /* n_gpu_layers */ {999 },
287287 /* rpc_servers */ {" " },
288288 /* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
289289 /* main_gpu */ {0 },
@@ -330,6 +330,7 @@ static void print_usage(int /* argc */, char ** argv) {
330330 printf (" -t, --threads <n> (default: %s)\n " , join (cmd_params_defaults.n_threads , " ," ).c_str ());
331331 printf (" -tgb, --threads-gen-batch <n1,n2> (default: %s)\n " , join (cmd_params_defaults.n_threads , " ," ).c_str ());
332332 printf (" -ngl, --n-gpu-layers <n> (default: %s)\n " , join (cmd_params_defaults.n_gpu_layers , " ," ).c_str ());
333+ printf (" --n-cpu-moe <n> (default: none)\n " );
333334 printf (" -rpc, --rpc <rpc_servers> (default: %s)\n " , join (cmd_params_defaults.rpc_servers , " ," ).c_str ());
334335 printf (" -sm, --split-mode <none|layer|row> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.split_mode , split_mode_str), " ," ).c_str ());
335336 printf (" -mg, --main-gpu <i> (default: %s)\n " , join (cmd_params_defaults.main_gpu , " ," ).c_str ());
@@ -428,6 +429,19 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
428429 }
429430 return true ;
430431}
432+ bool add_cpu_buft_overrides (const char * arg, std::vector<llama_model_tensor_buft_override>& overrides) {
433+ int n_layers = std::stoi (arg);
434+ if (n_layers < 0 ) {
435+ fprintf (stderr, " error: Invalid value for --n-cpu-moe: %s\n " , arg);
436+ return false ;
437+ }
438+ for (int32_t l = 0 ; l < n_layers; ++l) {
439+ std::string pattern = " blk\\ ." + std::to_string (l) + " \\ .(ffn_(up|down|gate)_exps\\ .weight)" ;
440+ overrides.push_back ({strdup (pattern.c_str ()), ggml_backend_cpu_buffer_type ()});
441+ }
442+ return true ;
443+ }
444+
431445template <class T1 , class T2 >
432446std::vector<std::pair<T1,T2>> string_split_pairs (const std::string & str, char delim) {
433447 std::vector<std::pair<T1,T2>> values;
@@ -800,6 +814,15 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
800814 invalid_param = true ;
801815 break ;
802816 }
817+ } else if (arg == " --n-cpu-moe" ) {
818+ if (++i >= argc) {
819+ invalid_param = true ;
820+ break ;
821+ }
822+ if (!add_cpu_buft_overrides (argv[i], params.buft_overrides )) {
823+ invalid_param = true ;
824+ break ;
825+ }
803826 } else {
804827 invalid_param = true ;
805828 break ;
0 commit comments