@@ -265,6 +265,7 @@ struct cmd_params {
265265 bool no_fug = false ;
266266 bool use_thp = false ;
267267 bool no_ooae = false ;
268+ bool mqkv = false ;
268269 output_formats output_format;
269270 output_formats output_format_stderr;
270271};
@@ -303,6 +304,7 @@ static const cmd_params cmd_params_defaults = {
303304 /* no_fug */ false ,
304305 /* use_thp */ false ,
305306 /* no_ooae */ false ,
307+ /* mqkv */ false ,
306308 /* output_format */ MARKDOWN,
307309 /* output_format_stderr */ NONE,
308310};
@@ -342,6 +344,7 @@ static void print_usage(int /* argc */, char ** argv) {
342344 printf (" -v, --verbose (default: %s)\n " , cmd_params_defaults.verbose ? " 1" : " 0" );
343345 printf (" -w, --warmup <0|1> (default: %s)\n " , cmd_params_defaults.warmup ? " 1" : " 0" );
344346 printf (" -rtr, --run-time-repack <0|1> (default: %s)\n " , cmd_params_defaults.repack ? " 1" : " 0" );
347+ printf (" -mqkv, --merge-qkv (default: %s)\n " , cmd_params_defaults.mqkv ? " 1" : " 0" );
345348 printf (" -thp, --transparent-huge-pages <0|1> (default: %s)\n " , cmd_params_defaults.use_thp ? " 1" : " 0" );
346349 printf (" -ot, --override-tensor pattern (default: none)\n " );
347350 printf (" -fmoe, --fused-moe <0|1> (default: %s)\n " , cmd_params_defaults.fmoe ? " 1" : " 0" );
@@ -733,6 +736,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
733736 break ;
734737 }
735738 params.repack = std::stoi (argv[i]);
739+ } else if (arg == " -mqkv" || arg == " --merge-qkv" ) {
740+ if (++i >= argc) {
741+ invalid_param = true ;
742+ break ;
743+ }
744+ params.mqkv = std::stoi (argv[i]);
736745 } else if (arg == " -thp" || arg == " --transparent-huge-pages" ) {
737746 if (++i >= argc) {
738747 invalid_param = true ;
@@ -851,6 +860,7 @@ struct cmd_params_instance {
851860 bool no_fug = false ;
852861 bool use_thp = false ;
853862 bool no_ooae = false ;
863+ bool mqkv = false ;
854864 const llama_model_tensor_buft_override* buft_overrides;
855865
856866 llama_model_params to_llama_mparams () const {
@@ -866,6 +876,7 @@ struct cmd_params_instance {
866876 mparams.use_mmap = use_mmap;
867877 mparams.repack_tensors = repack;
868878 mparams.use_thp = use_thp;
879+ mparams.merge_qkv = mqkv;
869880 mparams.tensor_buft_overrides = buft_overrides;
870881
871882 return mparams;
@@ -879,6 +890,7 @@ struct cmd_params_instance {
879890 main_gpu == other.main_gpu &&
880891 use_mmap == other.use_mmap &&
881892 repack == other.repack &&
893+ mqkv == other.mqkv &&
882894 use_thp == other.use_thp &&
883895 tensor_split == other.tensor_split ;
884896 }
@@ -961,6 +973,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
961973 /* .no_fug = */ params.no_fug ,
962974 /* .use_thp = */ params.use_thp ,
963975 /* .no_ooae = */ params.no_ooae ,
976+ /* .mqkv = */ params.mqkv ,
964977 /* .buft_overrides=*/ params.buft_overrides .data (),
965978 };
966979 instances.push_back (instance);
@@ -998,6 +1011,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
9981011 /* .no_fug = */ params.no_fug ,
9991012 /* .use_thp = */ params.use_thp ,
10001013 /* .no_ooae = */ params.no_ooae ,
1014+ /* .mqkv = */ params.mqkv ,
10011015 /* .buft_overrides=*/ params.buft_overrides .data (),
10021016 };
10031017 instances.push_back (instance);
@@ -1035,6 +1049,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10351049 /* .no_fug = */ params.no_fug ,
10361050 /* .use_thp = */ params.use_thp ,
10371051 /* .no_ooae = */ params.no_ooae ,
1052+ /* .mqkv = */ params.mqkv ,
10381053 /* .buft_overrides=*/ params.buft_overrides .data (),
10391054 };
10401055 instances.push_back (instance);
@@ -1071,7 +1086,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10711086 /* .ger = */ params.ger ,
10721087 /* .no_fug = */ params.no_fug ,
10731088 /* .use_thp = */ params.use_thp ,
1074- /* .no_ooae = */ params.no_ooae ,
1089+ /* .no_ooae = */ params.no_ooae ,
1090+ /* .mqkv = */ params.mqkv ,
10751091 /* .buft_overrides=*/ params.buft_overrides .data (),
10761092 };
10771093 instances.push_back (instance);
@@ -1120,6 +1136,7 @@ struct test {
11201136 bool no_fug = false ;
11211137 bool use_thp = false ;
11221138 bool no_ooae = false ;
1139+ bool mqkv = false ;
11231140 int n_prompt;
11241141 int n_gen;
11251142 std::string test_time;
@@ -1152,6 +1169,7 @@ struct test {
11521169 use_mmap = inst.use_mmap ;
11531170 embeddings = inst.embeddings ;
11541171 repack = inst.repack ;
1172+ mqkv = inst.mqkv ;
11551173 fmoe = inst.fmoe ;
11561174 ger = inst.ger ;
11571175 no_fug = inst.no_fug ;
@@ -1247,7 +1265,7 @@ struct test {
12471265 " n_threads" , " type_k" , " type_v" ,
12481266 " n_gpu_layers" , " split_mode" ,
12491267 " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" , " attn_max_batch" , " ser" ,
1250- " tensor_split" , " use_mmap" , " embeddings" , " repack" , " fused_moe" , " grouped_er" , " fused_up_gate" , " use_thp" , " ooae" ,
1268+ " tensor_split" , " use_mmap" , " embeddings" , " repack" , " mqkv " , " fused_moe" , " grouped_er" , " fused_up_gate" , " use_thp" , " ooae" ,
12511269 " n_prompt" , " n_gen" , " test_time" ,
12521270 " avg_ns" , " stddev_ns" ,
12531271 " avg_ts" , " stddev_ts" , " test" ,
@@ -1269,7 +1287,7 @@ struct test {
12691287 if (field == " cuda" || field == " vulkan" || field == " kompute" || field == " metal" ||
12701288 field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
12711289 field == " flash_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" || field == " use_thp" ||
1272- field == " fused_moe" || field == " grouped_er" || field == " fused_up_gate" || field == " ooae" ) {
1290+ field == " fused_moe" || field == " grouped_er" || field == " fused_up_gate" || field == " ooae" || field == " mqkv " ) {
12731291 return BOOL;
12741292 }
12751293 if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -1313,7 +1331,7 @@ struct test {
13131331 std::to_string (mla_attn), std::to_string (attn_max_batch), ser_to_string (ser),
13141332 tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
13151333 std::to_string (repack), std::to_string (fmoe), std::to_string (ger),
1316- std::to_string (no_fug), std::to_string (use_thp), std::to_string (no_ooae),
1334+ std::to_string (no_fug), std::to_string (use_thp), std::to_string (no_ooae), std::to_string (mqkv),
13171335 std::to_string (n_prompt), std::to_string (n_gen), test_time,
13181336 std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
13191337 std::to_string (avg_ts ()), std::to_string (stdev_ts ()),
@@ -1491,6 +1509,9 @@ struct markdown_printer : public printer {
14911509 if (field == " repack" ) {
14921510 return 3 ;
14931511 }
1512+ if (field == " mqkv" ) {
1513+ return 4 ;
1514+ }
14941515 if (field == " use_thp" ) {
14951516 return 3 ;
14961517 }
@@ -1549,6 +1570,9 @@ struct markdown_printer : public printer {
15491570 if (field == " repack" ) {
15501571 return " rtr" ;
15511572 }
1573+ if (field == " mqkv" ) {
1574+ return " mqkv" ;
1575+ }
15521576 if (field == " use_thp" ) {
15531577 return " thp" ;
15541578 }
@@ -1634,6 +1658,9 @@ struct markdown_printer : public printer {
16341658 if (params.repack != cmd_params_defaults.repack ) {
16351659 fields.emplace_back (" repack" );
16361660 }
1661+ if (params.mqkv != cmd_params_defaults.mqkv ) {
1662+ fields.emplace_back (" mqkv" );
1663+ }
16371664 if (params.use_thp != cmd_params_defaults.use_thp ) {
16381665 fields.emplace_back (" use_thp" );
16391666 }
0 commit comments