@@ -250,6 +250,7 @@ struct cmd_params {
     std::vector<bool>             cpu_strict;
     std::vector<int>              poll;
     std::vector<int>              n_gpu_layers;
+    std::vector<int>              n_cpu_moe;
     std::vector<std::string>      rpc_servers;
     std::vector<llama_split_mode> split_mode;
     std::vector<int>              main_gpu;
@@ -286,6 +287,7 @@ static const cmd_params cmd_params_defaults = {
     /* cpu_strict   */ { false },
     /* poll         */ { 50 },
     /* n_gpu_layers */ { 99 },
+    /* n_cpu_moe    */ { 0 },
     /* rpc_servers  */ { "" },
     /* split_mode   */ { LLAMA_SPLIT_MODE_LAYER },
     /* main_gpu     */ { 0 },
@@ -353,6 +355,8 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  --poll <0...100>                        (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str());
     printf("  -ngl, --n-gpu-layers <n>                (default: %s)\n",
            join(cmd_params_defaults.n_gpu_layers, ",").c_str());
+    printf("  -ncmoe, --n-cpu-moe <n>                 (default: %s)\n",
+           join(cmd_params_defaults.n_cpu_moe, ",").c_str());
     if (llama_supports_rpc()) {
         printf("  -rpc, --rpc <rpc_servers>               (default: %s)\n",
                join(cmd_params_defaults.rpc_servers, ",").c_str());
@@ -564,6 +568,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = parse_int_range(argv[i]);
             params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
+        } else if (arg == "-ncmoe" || arg == "--n-cpu-moe") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = parse_int_range(argv[i]);
+            params.n_cpu_moe.insert(params.n_cpu_moe.end(), p.begin(), p.end());
         } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -841,6 +852,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_gpu_layers.empty()) {
         params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
     }
+    if (params.n_cpu_moe.empty()) {
+        params.n_cpu_moe = cmd_params_defaults.n_cpu_moe;
+    }
     if (params.rpc_servers.empty()) {
         params.rpc_servers = cmd_params_defaults.rpc_servers;
     }
@@ -901,6 +915,7 @@ struct cmd_params_instance {
     bool             cpu_strict;
     int              poll;
     int              n_gpu_layers;
+    int              n_cpu_moe;
     std::string      rpc_servers_str;
     llama_split_mode split_mode;
     int              main_gpu;
@@ -973,20 +988,50 @@ struct cmd_params_instance {
         mparams.tensor_split = tensor_split.data();
         mparams.use_mmap     = use_mmap;
 
-        if (tensor_buft_overrides.empty()) {
-            mparams.tensor_buft_overrides = nullptr;
+        if (n_cpu_moe <= 0) {
+            if (tensor_buft_overrides.empty()) {
+                mparams.tensor_buft_overrides = nullptr;
+            } else {
+                GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr &&
+                            "Tensor buffer overrides not terminated with empty pattern");
+                mparams.tensor_buft_overrides = tensor_buft_overrides.data();
+            }
         } else {
-            GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
-            mparams.tensor_buft_overrides = tensor_buft_overrides.data();
+            static std::vector<llama_model_tensor_buft_override> merged;
+            static std::vector<std::string>                      patterns;
+
+            merged.clear();
+            patterns.clear();
+
+            auto first = tensor_buft_overrides.begin();
+            auto last  = tensor_buft_overrides.end();
+            if (first != last && (last - 1)->pattern == nullptr) {
+                --last;
+            }
+            merged.insert(merged.end(), first, last);
+
+            patterns.reserve((size_t) n_cpu_moe);
+            merged.reserve(merged.size() + (size_t) n_cpu_moe + 1);
+
+            for (int i = 0; i < n_cpu_moe; ++i) {
+                patterns.push_back(llm_ffn_exps_block_regex(i));
+                merged.push_back({ patterns.back().c_str(),
+                                   ggml_backend_cpu_buffer_type() });
+            }
+
+            merged.push_back({ nullptr, nullptr });
+
+            mparams.tensor_buft_overrides = merged.data();
         }
 
         return mparams;
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
-        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
-               split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
-               tensor_split == other.tensor_split && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && n_cpu_moe == other.n_cpu_moe &&
+               rpc_servers_str == other.rpc_servers_str && split_mode == other.split_mode &&
+               main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split &&
+               vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
     }
 
     llama_context_params to_llama_cparams() const {
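
A note on the override-merging hunk above: `mparams.tensor_buft_overrides` ends up pointing into `merged`, and each `llama_model_tensor_buft_override` entry only stores a `const char *` pattern, so both the override array and the regex strings must stay alive after the method returns. That is presumably why the patch keeps `merged` and `patterns` in function-local `static` vectors and reserves `patterns` before taking `c_str()` pointers into it. A standalone sketch of the same idea, with stand-in types instead of `llama_model_tensor_buft_override` / `ggml_backend_buffer_type_t`, and an assumed `blk.<i>.ffn_(up|down|gate)_exps`-style pattern in place of the real `llm_ffn_exps_block_regex()`:

// Standalone sketch, not part of the patch. Types and the pattern string are
// stand-ins / assumptions; the real code uses llama_model_tensor_buft_override,
// ggml_backend_cpu_buffer_type() and llm_ffn_exps_block_regex().
#include <cstdio>
#include <string>
#include <vector>

struct buft_override {        // stand-in for llama_model_tensor_buft_override
    const char * pattern;     // regex matched against tensor names; nullptr terminates the list
    const char * buft;        // stand-in for the target buffer type
};

static std::vector<std::string>   patterns;  // owns the regex strings referenced via c_str()
static std::vector<buft_override> merged;    // null-terminated array handed out as a raw pointer

const buft_override * build_cpu_moe_overrides(int n_cpu_moe) {
    patterns.clear();
    merged.clear();
    // reserve up front so growing the vectors cannot relocate strings whose
    // c_str() pointers are already stored in `merged`
    patterns.reserve((std::size_t) n_cpu_moe);
    merged.reserve((std::size_t) n_cpu_moe + 1);
    for (int i = 0; i < n_cpu_moe; ++i) {
        // assumed pattern shape: keep the expert FFN tensors of block i on CPU buffers
        patterns.push_back("blk\\." + std::to_string(i) + "\\.ffn_(up|down|gate)_exps");
        merged.push_back({ patterns.back().c_str(), "CPU" });
    }
    merged.push_back({ nullptr, nullptr });  // terminator expected by the consumer
    return merged.data();                    // stays valid because the vectors are static
}

int main() {
    for (const buft_override * ov = build_cpu_moe_overrides(2); ov->pattern; ++ov) {
        std::printf("%s -> %s\n", ov->pattern, ov->buft);
    }
    return 0;
}
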
@@ -1014,6 +1059,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     // clang-format off
     for (const auto & m : params.model)
     for (const auto & nl : params.n_gpu_layers)
+    for (const auto & ncmoe : params.n_cpu_moe)
     for (const auto & rpc : params.rpc_servers)
     for (const auto & sm : params.split_mode)
     for (const auto & mg : params.main_gpu)
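
The extra loop above makes `n_cpu_moe` one more axis of the parameter sweep in `get_cmd_params_instances()`: every value passed via `-ncmoe` is combined with every other swept parameter, producing one `cmd_params_instance` per combination. A hypothetical illustration (the values are made up, not from the patch):

// Hypothetical illustration of the Cartesian sweep, not part of the patch.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> n_gpu_layers = { 99 };        // e.g. -ngl 99
    std::vector<int> n_cpu_moe    = { 0, 8, 16 };  // e.g. -ncmoe 0,8,16

    std::size_t instances = 0;
    for (int nl : n_gpu_layers) {
        for (int ncmoe : n_cpu_moe) {
            (void) nl;
            (void) ncmoe;
            ++instances;                           // one cmd_params_instance per combination
        }
    }
    std::printf("benchmark instances: %zu\n", instances);  // prints 3
    return 0;
}
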
@@ -1051,6 +1097,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .cpu_strict   = */ cs,
                 /* .poll         = */ pl,
                 /* .n_gpu_layers = */ nl,
+                /* .n_cpu_moe    = */ ncmoe,
                 /* .rpc_servers  = */ rpc,
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
@@ -1083,6 +1130,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .cpu_strict   = */ cs,
                 /* .poll         = */ pl,
                 /* .n_gpu_layers = */ nl,
+                /* .n_cpu_moe    = */ ncmoe,
                 /* .rpc_servers  = */ rpc,
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
@@ -1115,6 +1163,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .cpu_strict   = */ cs,
                 /* .poll         = */ pl,
                 /* .n_gpu_layers = */ nl,
+                /* .n_cpu_moe    = */ ncmoe,
                 /* .rpc_servers  = */ rpc,
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
@@ -1152,6 +1201,7 @@ struct test {
     ggml_type        type_k;
     ggml_type        type_v;
     int              n_gpu_layers;
+    int              n_cpu_moe;
     llama_split_mode split_mode;
     int              main_gpu;
     bool             no_kv_offload;
@@ -1186,6 +1236,7 @@ struct test {
         type_k         = inst.type_k;
         type_v         = inst.type_v;
         n_gpu_layers   = inst.n_gpu_layers;
+        n_cpu_moe      = inst.n_cpu_moe;
         split_mode     = inst.split_mode;
         main_gpu       = inst.main_gpu;
         no_kv_offload  = inst.no_kv_offload;
@@ -1236,12 +1287,14 @@ struct test {
 
     static const std::vector<std::string> & get_fields() {
         static const std::vector<std::string> fields = {
-            "build_commit", "build_number", "cpu_info",       "gpu_info",   "backends",     "model_filename",
-            "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
-            "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
-            "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
-            "use_mmap",     "embeddings",   "no_op_offload",  "n_prompt",   "n_gen",        "n_depth",      "test_time",
-            "avg_ns",       "stddev_ns",    "avg_ts",         "stddev_ts",
+            "build_commit",   "build_number",  "cpu_info",      "gpu_info",       "backends",
+            "model_filename", "model_type",    "model_size",    "model_n_params", "n_batch",
+            "n_ubatch",       "n_threads",     "cpu_mask",      "cpu_strict",     "poll",
+            "type_k",         "type_v",        "n_gpu_layers",  "n_cpu_moe",      "split_mode",
+            "main_gpu",       "no_kv_offload", "flash_attn",    "tensor_split",   "tensor_buft_overrides",
+            "use_mmap",       "embeddings",    "no_op_offload", "n_prompt",       "n_gen",
+            "n_depth",        "test_time",     "avg_ns",        "stddev_ns",      "avg_ts",
+            "stddev_ts"
         };
         return fields;
     }
@@ -1251,8 +1304,8 @@ struct test {
     static field_type get_field_type(const std::string & field) {
         if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
             field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
-            field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" ||
-            field == "avg_ns" || field == "stddev_ns" || field == "no_op_offload") {
+            field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" || field == "avg_ns" ||
+            field == "stddev_ns" || field == "no_op_offload" || field == "n_cpu_moe") {
             return INT;
         }
         if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
@@ -1320,6 +1373,7 @@ struct test {
             ggml_type_name(type_k),
             ggml_type_name(type_v),
             std::to_string(n_gpu_layers),
+            std::to_string(n_cpu_moe),
             split_mode_str(split_mode),
             std::to_string(main_gpu),
             std::to_string(no_kv_offload),
@@ -1568,6 +1622,9 @@ struct markdown_printer : public printer {
         if (!is_cpu_backend) {
             fields.emplace_back("n_gpu_layers");
         }
+        if (params.n_cpu_moe.size() > 1) {
+            fields.emplace_back("n_cpu_moe");
+        }
         if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
             fields.emplace_back("n_threads");
         }
@@ -1683,7 +1740,8 @@ struct markdown_printer : public printer {
                 exit(1);
             }
 
-            int width = get_field_width(field);
+            unsigned int width = get_field_width(field);
+
             if (field == "t/s") {
                 // HACK: the utf-8 character is 2 bytes
                 width += 1;