@@ -261,6 +261,7 @@ struct cmd_params {
261261 bool warmup;
262262 bool repack = false ;
263263 bool fmoe = false ;
264+ bool ger = false ; // ger = Grouped Expert Routing
264265 bool no_fug = false ;
265266 bool use_thp = false ;
266267 output_formats output_format;
@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
296297 /* verbose */ false ,
297298 /* warmup */ true ,
298299 /* repack */ false ,
299- /* use_thp */ false ,
300300 /* fmoe */ false ,
301+ /* ger */ false ,
301302 /* no_fug */ false ,
303+ /* use_thp */ false ,
302304 /* output_format */ MARKDOWN,
303305 /* output_format_stderr */ NONE,
304306};
@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
341343 printf (" -thp, --transparent-huge-pages <0|1> (default: %s)\n " , cmd_params_defaults.use_thp ? " 1" : " 0" );
342344 printf (" -ot, --override-tensor pattern (default: none)\n " );
343345 printf (" -fmoe, --fused-moe <0|1> (default: %s)\n " , cmd_params_defaults.fmoe ? " 1" : " 0" );
346+ printf (" -ger, --grouped-expert-routing <0|1>(default: %s)\n " , cmd_params_defaults.ger ? " 1" : " 0" );
344347 printf (" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n " , cmd_params_defaults.no_fug ? " 1" : " 0" );
345348 printf (" \n " );
346349 printf (" Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n " );
@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
739742 break ;
740743 }
741744 params.fmoe = std::stoi (argv[i]);
745+ } else if (arg == " -ger" || arg == " --grouped-expert-routing" ) {
746+ if (++i >= argc) {
747+ invalid_param = true ;
748+ break ;
749+ }
750+ params.ger = std::stoi (argv[i]);
742751 } else if (arg == " -no-fug" || arg == " --no-fused-up-gate" ) {
743752 if (++i >= argc) {
744753 invalid_param = true ;
@@ -829,6 +838,7 @@ struct cmd_params_instance {
829838 bool embeddings;
830839 bool repack = false ;
831840 bool fmoe = false ;
841+ bool ger = false ;
832842 bool no_fug = false ;
833843 bool use_thp = false ;
834844 const llama_model_tensor_buft_override* buft_overrides;
@@ -876,6 +886,7 @@ struct cmd_params_instance {
876886 cparams.mla_attn = mla_attn;
877887 cparams.attn_max_batch = attn_max_batch;
878888 cparams.fused_moe_up_gate = fmoe;
889+ cparams.grouped_expert_routing = ger;
879890 cparams.fused_up_gate = !no_fug;
880891 cparams.min_experts = ser.first ;
881892 cparams.thresh_experts = ser.second ;
@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
935946 /* .embeddings = */ embd,
936947 /* .repack = */ params.repack ,
937948 /* .fmoe = */ params.fmoe ,
949+ /* .ger = */ params.ger ,
938950 /* .no_fug = */ params.no_fug ,
939951 /* .use_thp = */ params.use_thp ,
940952 /* .buft_overrides=*/ params.buft_overrides .data (),
@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
970982 /* .embeddings = */ embd,
971983 /* .repack = */ params.repack ,
972984 /* .fmoe = */ params.fmoe ,
985+ /* .ger = */ params.ger ,
973986 /* .no_fug = */ params.no_fug ,
974987 /* .use_thp = */ params.use_thp ,
975988 /* .buft_overrides=*/ params.buft_overrides .data (),
@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10051018 /* .embeddings = */ embd,
10061019 /* .repack = */ params.repack ,
10071020 /* .fmoe = */ params.fmoe ,
1021+ /* .ger = */ params.ger ,
10081022 /* .no_fug = */ params.no_fug ,
10091023 /* .use_thp = */ params.use_thp ,
10101024 /* .buft_overrides=*/ params.buft_overrides .data (),
@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10401054 /* .embeddings = */ embd,
10411055 /* .repack = */ params.repack ,
10421056 /* .fmoe = */ params.fmoe ,
1057+ /* .ger = */ params.ger ,
10431058 /* .no_fug = */ params.no_fug ,
10441059 /* .use_thp = */ params.use_thp ,
10451060 /* .buft_overrides=*/ params.buft_overrides .data (),
@@ -1086,6 +1101,7 @@ struct test {
10861101 bool embeddings;
10871102 bool repack = false ;
10881103 bool fmoe = false ;
1104+ bool ger = false ;
10891105 bool no_fug = false ;
10901106 bool use_thp = false ;
10911107 int n_prompt;
@@ -1120,6 +1136,8 @@ struct test {
11201136 use_mmap = inst.use_mmap ;
11211137 embeddings = inst.embeddings ;
11221138 repack = inst.repack ;
1139+ fmoe = inst.fmoe ;
1140+ ger = inst.ger ;
11231141 no_fug = inst.no_fug ;
11241142 use_thp = inst.use_thp ;
11251143 n_prompt = inst.n_prompt ;
@@ -1212,7 +1230,7 @@ struct test {
12121230 " n_threads" , " type_k" , " type_v" ,
12131231 " n_gpu_layers" , " split_mode" ,
12141232 " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" , " attn_max_batch" , " ser" ,
1215- " tensor_split" , " use_mmap" , " embeddings" , " repack" , " fused_moe" , " fused_up_gate" , " use_thp" ,
1233+ " tensor_split" , " use_mmap" , " embeddings" , " repack" , " fused_moe" , " grouped_er " , " fused_up_gate" , " use_thp" ,
12161234 " n_prompt" , " n_gen" , " test_time" ,
12171235 " avg_ns" , " stddev_ns" ,
12181236 " avg_ts" , " stddev_ts" , " test" ,
@@ -1234,7 +1252,7 @@ struct test {
12341252 if (field == " cuda" || field == " vulkan" || field == " kompute" || field == " metal" ||
12351253 field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
12361254 field == " flash_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" || field == " use_thp" ||
1237- field == " fused_moe" || field == " fused_up_gate" ) {
1255+ field == " fused_moe" || field == " grouped_er " || field == " fused_up_gate" ) {
12381256 return BOOL;
12391257 }
12401258 if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -1277,7 +1295,8 @@ struct test {
12771295 std::to_string (main_gpu), std::to_string (no_kv_offload), std::to_string (flash_attn),
12781296 std::to_string (mla_attn), std::to_string (attn_max_batch), ser_to_string (ser),
12791297 tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
1280- std::to_string (repack), std::to_string (fmoe), std::to_string (no_fug), std::to_string (use_thp),
1298+ std::to_string (repack), std::to_string (fmoe), std::to_string (ger),
1299+ std::to_string (no_fug), std::to_string (use_thp),
12811300 std::to_string (n_prompt), std::to_string (n_gen), test_time,
12821301 std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
12831302 std::to_string (avg_ts ()), std::to_string (stdev_ts ()),
@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
14611480 if (field == " fused_moe" ) {
14621481 return 4 ;
14631482 }
1483+ if (field == " grouped_er" ) {
1484+ return 3 ;
1485+ }
14641486 if (field == " fused_up_gate" ) {
14651487 return 6 ;
14661488 }
@@ -1513,6 +1535,12 @@ struct markdown_printer : public printer {
15131535 if (field == " fused_moe" ) {
15141536 return " fmoe" ;
15151537 }
1538+ if (field == " grouped_er" ) {
1539+ return " ger" ;
1540+ }
1541+ if (field == " grouped_er" ) {
1542+ return " ger" ;
1543+ }
15161544 if (field == " fused_up_gate" ) {
15171545 return " no-fug" ;
15181546 }
@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
15891617 if (params.fmoe != cmd_params_defaults.fmoe ) {
15901618 fields.emplace_back (" fused_moe" );
15911619 }
1620+ if (params.ger != cmd_params_defaults.ger ) {
1621+ fields.emplace_back (" grouped_er" );
1622+ }
15921623 if (params.no_fug != cmd_params_defaults.no_fug ) {
15931624 fields.emplace_back (" fused_up_gate" );
15941625 }
0 commit comments