@@ -264,6 +264,7 @@ struct cmd_params {
264264 bool ger = false ; // ger = Grouped Expert Routing
265265 bool no_fug = false ;
266266 bool use_thp = false ;
267+ bool no_ooae = false ;
267268 output_formats output_format;
268269 output_formats output_format_stderr;
269270};
@@ -301,6 +302,7 @@ static const cmd_params cmd_params_defaults = {
301302 /* ger */ false ,
302303 /* no_fug */ false ,
303304 /* use_thp */ false ,
305+ /* no_ooae */ false ,
304306 /* output_format */ MARKDOWN,
305307 /* output_format_stderr */ NONE,
306308};
@@ -345,6 +347,7 @@ static void print_usage(int /* argc */, char ** argv) {
345347 printf (" -fmoe, --fused-moe <0|1> (default: %s)\n " , cmd_params_defaults.fmoe ? " 1" : " 0" );
346348 printf (" -ger, --grouped-expert-routing <0|1>(default: %s)\n " , cmd_params_defaults.ger ? " 1" : " 0" );
347349 printf (" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n " , cmd_params_defaults.no_fug ? " 1" : " 0" );
350+ printf (" -no-ooae, --no-offload-only-active-experts <0|1> (default: %s)\n " , cmd_params_defaults.no_ooae ? " 1" : " 0" );
348351 printf (" \n " );
349352 printf (" Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n " );
350353}
@@ -754,6 +757,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
754757 break ;
755758 }
756759 params.no_fug = std::stoi (argv[i]);
760+ } else if (arg == " -no-ooae" || arg == " --no-offload-only-active-experts" ) {
761+ if (++i >= argc) {
762+ invalid_param = true ;
763+ break ;
764+ }
765+ params.no_ooae = std::stoi (argv[i]);
757766 } else if (arg == " -ot" || arg == " --override-tensor" ) {
758767 if (++i >= argc) {
759768 invalid_param = true ;
@@ -841,6 +850,7 @@ struct cmd_params_instance {
841850 bool ger = false ;
842851 bool no_fug = false ;
843852 bool use_thp = false ;
853+ bool no_ooae = false ;
844854 const llama_model_tensor_buft_override* buft_overrides;
845855
846856 llama_model_params to_llama_mparams () const {
@@ -888,6 +898,7 @@ struct cmd_params_instance {
888898 cparams.fused_moe_up_gate = fmoe;
889899 cparams.grouped_expert_routing = ger;
890900 cparams.fused_up_gate = !no_fug;
901+ cparams.only_active_experts = !no_ooae;
891902 cparams.min_experts = ser.first ;
892903 cparams.thresh_experts = ser.second ;
893904 cparams.embeddings = embeddings;
@@ -949,6 +960,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
949960 /* .ger = */ params.ger ,
950961 /* .no_fug = */ params.no_fug ,
951962 /* .use_thp = */ params.use_thp ,
963+ /* .no_ooae = */ params.no_ooae ,
952964 /* .buft_overrides=*/ params.buft_overrides .data (),
953965 };
954966 instances.push_back (instance);
@@ -985,6 +997,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
985997 /* .ger = */ params.ger ,
986998 /* .no_fug = */ params.no_fug ,
987999 /* .use_thp = */ params.use_thp ,
1000+ /* .no_ooae = */ params.no_ooae ,
9881001 /* .buft_overrides=*/ params.buft_overrides .data (),
9891002 };
9901003 instances.push_back (instance);
@@ -1021,6 +1034,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10211034 /* .ger = */ params.ger ,
10221035 /* .no_fug = */ params.no_fug ,
10231036 /* .use_thp = */ params.use_thp ,
1037+ /* .no_ooae = */ params.no_ooae ,
10241038 /* .buft_overrides=*/ params.buft_overrides .data (),
10251039 };
10261040 instances.push_back (instance);
@@ -1057,6 +1071,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10571071 /* .ger = */ params.ger ,
10581072 /* .no_fug = */ params.no_fug ,
10591073 /* .use_thp = */ params.use_thp ,
1074+ /* .no_ooae = */ params.no_ooae ,
10601075 /* .buft_overrides=*/ params.buft_overrides .data (),
10611076 };
10621077 instances.push_back (instance);
@@ -1104,6 +1119,7 @@ struct test {
11041119 bool ger = false ;
11051120 bool no_fug = false ;
11061121 bool use_thp = false ;
1122+ bool no_ooae = false ;
11071123 int n_prompt;
11081124 int n_gen;
11091125 std::string test_time;
@@ -1140,6 +1156,7 @@ struct test {
11401156 ger = inst.ger ;
11411157 no_fug = inst.no_fug ;
11421158 use_thp = inst.use_thp ;
1159+ no_ooae = inst.no_ooae ;
11431160 n_prompt = inst.n_prompt ;
11441161 n_gen = inst.n_gen ;
11451162 test_kind = inst.test_kind ;
@@ -1230,7 +1247,7 @@ struct test {
12301247 " n_threads" , " type_k" , " type_v" ,
12311248 " n_gpu_layers" , " split_mode" ,
12321249 " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" , " attn_max_batch" , " ser" ,
1233- " tensor_split" , " use_mmap" , " embeddings" , " repack" , " fused_moe" , " grouped_er" , " fused_up_gate" , " use_thp" ,
1250+ " tensor_split" , " use_mmap" , " embeddings" , " repack" , " fused_moe" , " grouped_er" , " fused_up_gate" , " use_thp" , " ooae " ,
12341251 " n_prompt" , " n_gen" , " test_time" ,
12351252 " avg_ns" , " stddev_ns" ,
12361253 " avg_ts" , " stddev_ts" , " test" ,
@@ -1252,7 +1269,7 @@ struct test {
12521269 if (field == " cuda" || field == " vulkan" || field == " kompute" || field == " metal" ||
12531270 field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
12541271 field == " flash_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" || field == " use_thp" ||
1255- field == " fused_moe" || field == " grouped_er" || field == " fused_up_gate" ) {
1272+ field == " fused_moe" || field == " grouped_er" || field == " fused_up_gate" || field == " ooae " ) {
12561273 return BOOL;
12571274 }
12581275 if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -1296,7 +1313,7 @@ struct test {
12961313 std::to_string (mla_attn), std::to_string (attn_max_batch), ser_to_string (ser),
12971314 tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
12981315 std::to_string (repack), std::to_string (fmoe), std::to_string (ger),
1299- std::to_string (no_fug), std::to_string (use_thp),
1316+ std::to_string (no_fug), std::to_string (use_thp), std::to_string (no_ooae),
13001317 std::to_string (n_prompt), std::to_string (n_gen), test_time,
13011318 std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
13021319 std::to_string (avg_ts ()), std::to_string (stdev_ts ()),
@@ -1486,6 +1503,9 @@ struct markdown_printer : public printer {
14861503 if (field == " fused_up_gate" ) {
14871504 return 6 ;
14881505 }
1506+ if (field == " ooae" ) {
1507+ return 7 ;
1508+ }
14891509 if (field == " test" ) {
14901510 return 13 ;
14911511 }
@@ -1544,6 +1564,9 @@ struct markdown_printer : public printer {
15441564 if (field == " fused_up_gate" ) {
15451565 return " no-fug" ;
15461566 }
1567+ if (field == " ooae" ) {
1568+ return " no-ooae" ;
1569+ }
15471570 if (field == " embeddings" ) {
15481571 return " embd" ;
15491572 }
@@ -1623,6 +1646,9 @@ struct markdown_printer : public printer {
16231646 if (params.no_fug != cmd_params_defaults.no_fug ) {
16241647 fields.emplace_back (" fused_up_gate" );
16251648 }
1649+ if (params.no_ooae != cmd_params_defaults.no_ooae ) {
1650+ fields.emplace_back (" ooae" );
1651+ }
16261652 fields.emplace_back (" test" );
16271653 fields.emplace_back (" t/s" );
16281654
0 commit comments