@@ -251,6 +251,7 @@ struct cmd_params {
251251 std::vector<int > mla_attn;
252252 std::vector<int > attn_max_batch;
253253 std::vector<Ser> ser;
254+ std::vector<bool > reuse;
254255 std::vector<std::vector<float >> tensor_split;
255256 std::vector<bool > use_mmap;
256257 std::vector<bool > embeddings;
@@ -292,6 +293,7 @@ static const cmd_params cmd_params_defaults = {
292293 /* mla_attn */ {3 },
293294 /* attn_max_batch */ {0 },
294295 /* ser */ {{-1 ,0 .0f }},
296+ /* reuse */ {false },
295297 /* tensor_split */ {std::vector<float >(llama_max_devices (), 0 .0f )},
296298 /* use_mmap */ {true },
297299 /* embeddings */ {false },
@@ -339,6 +341,7 @@ static void print_usage(int /* argc */, char ** argv) {
339341 printf (" -mla, --mla-attn <0|1|2> (default: %s)\n " , join (cmd_params_defaults.mla_attn , " ," ).c_str ());
340342 printf (" -amb, --attn-max-batch <i> (default: %s)\n " , join (cmd_params_defaults.attn_max_batch , " ," ).c_str ());
341343 printf (" -ser, --smart-expert-reduction <i,f>(default: %s)\n " , join (cmd_params_defaults.attn_max_batch , " ," ).c_str ());
344+ printf (" -gr, --graph-reuse <0|1> (default: %s)\n " , join (cmd_params_defaults.reuse , " ," ).c_str ());
342345 printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
343346 printf (" --numa <distribute|isolate|numactl> (default: disabled)\n " );
344347 printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
@@ -681,6 +684,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
681684 }
682685 auto p = string_split<int >(argv[i], split_delim);
683686 params.attn_max_batch .insert (params.attn_max_batch .end (), p.begin (), p.end ());
687+ } else if (arg == " -gr" || arg == " --graph-reuse" ) {
688+ if (++i >= argc) {
689+ invalid_param = true ;
690+ break ;
691+ }
692+ auto p = string_split<bool >(argv[i], split_delim);
693+ params.reuse .insert (params.reuse .end (), p.begin (), p.end ());
684694 } else if (arg == " -ser" || arg == " --smart-expert-reduction" ) {
685695 if (++i >= argc) {
686696 invalid_param = true ;
@@ -852,6 +862,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
852862 if (params.flash_attn .empty ()) { params.flash_attn = cmd_params_defaults.flash_attn ; }
853863 if (params.mla_attn .empty ()) { params.mla_attn = cmd_params_defaults.mla_attn ; }
854864 if (params.attn_max_batch .empty ()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch ; }
865+ if (params.reuse .empty ()) { params.reuse = cmd_params_defaults.reuse ; }
855866 if (params.ser .empty ()) { params.ser = cmd_params_defaults.ser ; }
856867 if (params.tensor_split .empty ()) { params.tensor_split = cmd_params_defaults.tensor_split ; }
857868 if (params.use_mmap .empty ()) { params.use_mmap = cmd_params_defaults.use_mmap ; }
@@ -891,6 +902,7 @@ struct cmd_params_instance {
891902 bool flash_attn;
892903 int mla_attn;
893904 int attn_max_batch;
905+ bool reuse;
894906 Ser ser;
895907 std::vector<float > tensor_split;
896908 std::string cuda_params;
@@ -950,6 +962,7 @@ struct cmd_params_instance {
950962 cparams.flash_attn = flash_attn;
951963 cparams.mla_attn = mla_attn;
952964 cparams.attn_max_batch = attn_max_batch;
965+ cparams.graph_reuse = reuse;
953966 cparams.fused_moe_up_gate = fmoe;
954967 cparams.grouped_expert_routing = ger;
955968 cparams.rope_cache = rcache;
@@ -984,6 +997,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
984997 for (const auto & fa : params.flash_attn )
985998 for (const auto & mla : params.mla_attn )
986999 for (const auto & amb : params.attn_max_batch )
1000+ for (const auto & reuse : params.reuse )
9871001 for (const auto & ser : params.ser )
9881002 for (const auto & nt : params.n_threads ) {
9891003 for (const auto & n_prompt : params.n_prompt ) {
@@ -1008,6 +1022,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10081022 /* .flash_attn = */ fa,
10091023 /* .mla_attn = */ mla,
10101024 /* .attn_max_b = */ amb,
1025+ /* .reuse = */ reuse,
10111026 /* .ser = */ ser,
10121027 /* .tensor_split = */ ts,
10131028 /* .cuda_params = */ params.cuda_params ,
@@ -1048,6 +1063,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10481063 /* .flash_attn = */ fa,
10491064 /* .mla_attn = */ mla,
10501065 /* .attn_max_b = */ amb,
1066+ /* .reuse = */ reuse,
10511067 /* .ser = */ ser,
10521068 /* .tensor_split = */ ts,
10531069 /* .cuda_params = */ params.cuda_params ,
@@ -1088,6 +1104,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10881104 /* .flash_attn = */ fa,
10891105 /* .mla_attn = */ mla,
10901106 /* .attn_max_b = */ amb,
1107+ /* .reuse = */ reuse,
10911108 /* .ser = */ ser,
10921109 /* .tensor_split = */ ts,
10931110 /* .cuda_params = */ params.cuda_params ,
@@ -1128,6 +1145,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
11281145 /* .flash_attn = */ fa,
11291146 /* .mla_attn = */ mla,
11301147 /* .attn_max_b = */ amb,
1148+ /* .reuse = */ reuse,
11311149 /* .ser = */ ser,
11321150 /* .tensor_split = */ ts,
11331151 /* .cuda_params = */ params.cuda_params ,
@@ -1179,6 +1197,7 @@ struct test {
11791197 bool flash_attn;
11801198 int mla_attn;
11811199 int attn_max_batch;
1200+ bool reuse;
11821201 Ser ser;
11831202 std::vector<float > tensor_split;
11841203 std::string cuda_params;
@@ -1219,6 +1238,7 @@ struct test {
12191238 flash_attn = inst.flash_attn ;
12201239 mla_attn = inst.mla_attn ;
12211240 attn_max_batch = inst.attn_max_batch ;
1241+ reuse = inst.reuse ;
12221242 ser = inst.ser ;
12231243 tensor_split = inst.tensor_split ;
12241244 cuda_params = inst.cuda_params ;
@@ -1321,7 +1341,7 @@ struct test {
13211341 " n_batch" , " n_ubatch" ,
13221342 " n_threads" , " type_k" , " type_v" ,
13231343 " n_gpu_layers" , " split_mode" ,
1324- " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" , " attn_max_batch" , " ser" ,
1344+ " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" , " attn_max_batch" , " ser" , " reuse" ,
13251345 " tensor_split" , " use_mmap" , " embeddings" , " repack" , " mqkv" , " fused_moe" , " grouped_er" ,
13261346 " fused_up_gate" , " use_thp" , " ooae" , " rcache" ,
13271347 " n_prompt" , " n_gen" , " test_time" ,
@@ -1346,7 +1366,7 @@ struct test {
13461366 field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
13471367 field == " flash_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" || field == " use_thp" ||
13481368 field == " fused_moe" || field == " grouped_er" || field == " fused_up_gate" || field == " ooae" || field == " mqkv" ||
1349- field == " rcache" ) {
1369+ field == " rcache" || field == " reuse" ) {
13501370 return BOOL;
13511371 }
13521372 if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -1387,7 +1407,7 @@ struct test {
13871407 std::to_string (is_gen ? n_threads.first : n_threads.second ), ggml_type_name (type_k), ggml_type_name (type_v),
13881408 std::to_string (n_gpu_layers), split_mode_str (split_mode),
13891409 std::to_string (main_gpu), std::to_string (no_kv_offload), std::to_string (flash_attn),
1390- std::to_string (mla_attn), std::to_string (attn_max_batch), ser_to_string (ser),
1410+ std::to_string (mla_attn), std::to_string (attn_max_batch), ser_to_string (ser), std::to_string (reuse),
13911411 tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
13921412 std::to_string (repack), std::to_string (fmoe), std::to_string (ger), std::to_string (rcache),
13931413 std::to_string (no_fug), std::to_string (use_thp), std::to_string (no_ooae), std::to_string (mqkv),
@@ -1559,6 +1579,9 @@ struct markdown_printer : public printer {
15591579 if (field == " attn_max_batch" ) {
15601580 return 5 ;
15611581 }
1582+ if (field == " reuse" ) {
1583+ return 2 ;
1584+ }
15621585 if (field == " ser" ) {
15631586 return 10 ;
15641587 }
@@ -1623,7 +1646,10 @@ struct markdown_printer : public printer {
16231646 if (field == " attn_max_batch" ) {
16241647 return " amb" ;
16251648 }
1626- if (field == " attn_max_batch" ) {
1649+ if (field == " reuse" ) {
1650+ return " gr" ;
1651+ }
1652+ if (field == " ser" ) {
16271653 return " ser" ;
16281654 }
16291655 if (field == " use_mmap" ) {
@@ -1702,9 +1728,12 @@ struct markdown_printer : public printer {
17021728 if (params.mla_attn .size () > 1 || params.mla_attn != cmd_params_defaults.mla_attn ) {
17031729 fields.emplace_back (" mla_attn" );
17041730 }
1705- if (params.attn_max_batch .size () > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn ) {
1731+ if (params.attn_max_batch .size () > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch ) {
17061732 fields.emplace_back (" attn_max_batch" );
17071733 }
1734+ if (params.reuse .size () > 1 || params.reuse != cmd_params_defaults.reuse ) {
1735+ fields.emplace_back (" reuse" );
1736+ }
17081737 if (params.ser .size () > 1 || params.ser != cmd_params_defaults.ser ) {
17091738 fields.emplace_back (" ser" );
17101739 }
0 commit comments