@@ -232,6 +232,7 @@ struct cmd_params {
232232 std::vector<int > main_gpu;
233233 std::vector<bool > no_kv_offload;
234234 std::vector<bool > flash_attn;
235+ std::vector<bool > mla_attn;
235236 std::vector<std::vector<float >> tensor_split;
236237 std::vector<bool > use_mmap;
237238 std::vector<bool > embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
261262 /* main_gpu */ {0 },
262263 /* no_kv_offload */ {false },
263264 /* flash_attn */ {false },
265+ /* mla_attn */ {false },
264266 /* tensor_split */ {std::vector<float >(llama_max_devices (), 0 .0f )},
265267 /* use_mmap */ {true },
266268 /* embeddings */ {false },
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
294296 printf (" -mg, --main-gpu <i> (default: %s)\n " , join (cmd_params_defaults.main_gpu , " ," ).c_str ());
295297 printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
296298 printf (" -fa, --flash-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.flash_attn , " ," ).c_str ());
299+ printf (" -mla, --mla-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.mla_attn , " ," ).c_str ());
297300 printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
298301 printf (" --numa <distribute|isolate|numactl> (default: disabled)\n " );
299302 printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
526529 }
527530 auto p = string_split<bool >(argv[i], split_delim);
528531 params.flash_attn .insert (params.flash_attn .end (), p.begin (), p.end ());
532+ } else if (arg == " -mla" || arg == " --mla-attn" ) {
533+ if (++i >= argc) {
534+ invalid_param = true ;
535+ break ;
536+ }
537+ auto p = string_split<bool >(argv[i], split_delim);
538+ params.mla_attn .insert (params.mla_attn .end (), p.begin (), p.end ());
529539 } else if (arg == " -mmp" || arg == " --mmap" ) {
530540 if (++i >= argc) {
531541 invalid_param = true ;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
621631 if (params.main_gpu .empty ()) { params.main_gpu = cmd_params_defaults.main_gpu ; }
622632 if (params.no_kv_offload .empty ()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload ; }
623633 if (params.flash_attn .empty ()) { params.flash_attn = cmd_params_defaults.flash_attn ; }
634+ if (params.mla_attn .empty ()) { params.mla_attn = cmd_params_defaults.mla_attn ; }
624635 if (params.tensor_split .empty ()) { params.tensor_split = cmd_params_defaults.tensor_split ; }
625636 if (params.use_mmap .empty ()) { params.use_mmap = cmd_params_defaults.use_mmap ; }
626637 if (params.embeddings .empty ()) { params.embeddings = cmd_params_defaults.embeddings ; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
656667 int main_gpu;
657668 bool no_kv_offload;
658669 bool flash_attn;
670+ bool mla_attn;
659671 std::vector<float > tensor_split;
660672 bool use_mmap;
661673 bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
698710 cparams.type_v = type_v;
699711 cparams.offload_kqv = !no_kv_offload;
700712 cparams.flash_attn = flash_attn;
713+ cparams.mla_attn = mla_attn;
701714 cparams.embeddings = embeddings;
702715
703716 return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
722735 for (const auto & tv : params.type_v )
723736 for (const auto & nkvo : params.no_kv_offload )
724737 for (const auto & fa : params.flash_attn )
738+ for (const auto & mla : params.mla_attn )
725739 for (const auto & nt : params.n_threads ) {
726740 for (const auto & n_prompt : params.n_prompt ) {
727741 if (n_prompt == 0 ) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
743757 /* .main_gpu = */ mg,
744758 /* .no_kv_offload= */ nkvo,
745759 /* .flash_attn = */ fa,
760+ /* .mla_attn = */ mla,
746761 /* .tensor_split = */ ts,
747762 /* .use_mmap = */ mmp,
748763 /* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
771786 /* .main_gpu = */ mg,
772787 /* .no_kv_offload= */ nkvo,
773788 /* .flash_attn = */ fa,
789+ /* .mla_attn = */ mla,
774790 /* .tensor_split = */ ts,
775791 /* .use_mmap = */ mmp,
776792 /* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
799815 /* .main_gpu = */ mg,
800816 /* .no_kv_offload= */ nkvo,
801817 /* .flash_attn = */ fa,
818+ /* .mla_attn = */ mla,
802819 /* .tensor_split = */ ts,
803820 /* .use_mmap = */ mmp,
804821 /* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
827844 /* .main_gpu = */ mg,
828845 /* .no_kv_offload= */ nkvo,
829846 /* .flash_attn = */ fa,
847+ /* .mla_attn = */ mla,
830848 /* .tensor_split = */ ts,
831849 /* .use_mmap = */ mmp,
832850 /* .embeddings = */ embd,
@@ -866,6 +884,7 @@ struct test {
866884 int main_gpu;
867885 bool no_kv_offload;
868886 bool flash_attn;
887+ bool mla_attn;
869888 std::vector<float > tensor_split;
870889 bool use_mmap;
871890 bool embeddings;
@@ -895,6 +914,7 @@ struct test {
895914 main_gpu = inst.main_gpu ;
896915 no_kv_offload = inst.no_kv_offload ;
897916 flash_attn = inst.flash_attn ;
917+ mla_attn = inst.mla_attn ;
898918 tensor_split = inst.tensor_split ;
899919 use_mmap = inst.use_mmap ;
900920 embeddings = inst.embeddings ;
@@ -988,7 +1008,7 @@ struct test {
9881008 " n_batch" , " n_ubatch" ,
9891009 " n_threads" , " type_k" , " type_v" ,
9901010 " n_gpu_layers" , " split_mode" ,
991- " main_gpu" , " no_kv_offload" , " flash_attn" ,
1011+ " main_gpu" , " no_kv_offload" , " flash_attn" , " mla_attn" ,
9921012 " tensor_split" , " use_mmap" , " embeddings" , " repack" ,
9931013 " n_prompt" , " n_gen" , " test_time" ,
9941014 " avg_ns" , " stddev_ns" ,
@@ -1010,7 +1030,7 @@ struct test {
10101030 }
10111031 if (field == " cuda" || field == " vulkan" || field == " kompute" || field == " metal" ||
10121032 field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
1013- field == " flash_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" ) {
1033+ field == " flash_attn" || field == " mla_attn" || field == " use_mmap" || field == " embeddings" || field == " repack" ) {
10141034 return BOOL;
10151035 }
10161036 if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -1044,7 +1064,7 @@ struct test {
10441064 std::to_string (n_batch), std::to_string (n_ubatch),
10451065 std::to_string (n_threads), ggml_type_name (type_k), ggml_type_name (type_v),
10461066 std::to_string (n_gpu_layers), split_mode_str (split_mode),
1047- std::to_string (main_gpu), std::to_string (no_kv_offload), std::to_string (flash_attn),
1067+ std::to_string (main_gpu), std::to_string (no_kv_offload), std::to_string (flash_attn), std::to_string (mla_attn),
10481068 tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings), std::to_string (repack),
10491069 std::to_string (n_prompt), std::to_string (n_gen), test_time,
10501070 std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
12081228 if (field == " flash_attn" ) {
12091229 return 2 ;
12101230 }
1231+ if (field == " mla_attn" ) {
1232+ return 3 ;
1233+ }
12111234 if (field == " use_mmap" ) {
12121235 return 4 ;
12131236 }
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
12421265 if (field == " flash_attn" ) {
12431266 return " fa" ;
12441267 }
1268+ if (field == " mla_attn" ) {
1269+ return " mla" ;
1270+ }
12451271 if (field == " use_mmap" ) {
12461272 return " mmap" ;
12471273 }
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
12941320 if (params.flash_attn .size () > 1 || params.flash_attn != cmd_params_defaults.flash_attn ) {
12951321 fields.emplace_back (" flash_attn" );
12961322 }
1323+ if (params.mla_attn .size () > 1 || params.mla_attn != cmd_params_defaults.mla_attn ) {
1324+ fields.emplace_back (" mla_attn" );
1325+ }
12971326 if (params.tensor_split .size () > 1 || params.tensor_split != cmd_params_defaults.tensor_split ) {
12981327 fields.emplace_back (" tensor_split" );
12991328 }
0 commit comments