@@ -174,6 +174,7 @@ struct cmd_params {
     std::vector<llama_split_mode> split_mode;
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
+    std::vector<bool> flash_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode    */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
+    /* flash_attn    */ {false},
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},
@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -sm, --split-mode <none|layer|row>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
     printf("  -mg, --main-gpu <i>                 (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
+    printf("  -fa, --flash-attn <0|1>             (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<bool>(argv[i], split_delim);
+            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.split_mode.empty())    { params.split_mode = cmd_params_defaults.split_mode; }
     if (params.main_gpu.empty())      { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()) { params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
+    if (params.flash_attn.empty())    { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.tensor_split.empty())  { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty())      { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty())    { params.embeddings = cmd_params_defaults.embeddings; }
@@ -498,6 +509,7 @@ struct cmd_params_instance {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -532,6 +544,7 @@ struct cmd_params_instance {
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
+        cparams.flash_attn = flash_attn;
         cparams.embeddings = embeddings;

         return cparams;
@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
+    for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -633,6 +649,7 @@ struct test {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -657,6 +674,7 @@ struct test {
         split_mode = inst.split_mode;
         main_gpu = inst.main_gpu;
         no_kv_offload = inst.no_kv_offload;
+        flash_attn = inst.flash_attn;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
@@ -731,7 +749,7 @@ struct test {
731749 " n_batch" , " n_ubatch" ,
732750 " n_threads" , " type_k" , " type_v" ,
733751 " n_gpu_layers" , " split_mode" ,
734- " main_gpu" , " no_kv_offload" ,
752+ " main_gpu" , " no_kv_offload" , " flash_attn " ,
735753 " tensor_split" , " use_mmap" , " embeddings" ,
736754 " n_prompt" , " n_gen" , " test_time" ,
737755 " avg_ns" , " stddev_ns" ,
@@ -753,7 +771,7 @@ struct test {
         }
         if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" || field == "f16_kv" || field == "no_kv_offload" ||
-            field == "use_mmap" || field == "embeddings") {
+            field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -787,7 +805,7 @@ struct test {
             std::to_string(n_batch), std::to_string(n_ubatch),
             std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
-            std::to_string(main_gpu), std::to_string(no_kv_offload),
+            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
         if (field == "no_kv_offload") {
             return "nkvo";
         }
+        if (field == "flash_attn") {
+            return "fa";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }
@@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
         if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
             fields.emplace_back("no_kv_offload");
         }
+        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
+            fields.emplace_back("flash_attn");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
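
For reference, a minimal sketch of how the new option could be exercised once this change is built; the model path is only a placeholder, and the comma-separated value list follows the same 0/1 syntax the diff uses for -nkvo and -mmp:

    # Benchmark with FlashAttention disabled and enabled in a single sweep.
    # The markdown output gains an "fa" column whenever the values differ from the default.
    ./llama-bench -m models/7B/ggml-model-q4_0.gguf -fa 0,1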