@@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = true;
         return true;
     }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
     if (arg == "--color") {
         params.use_color = true;
         return true;
@@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv       = !params.no_kv_offload;
+    cparams.flash_attn        = params.flash_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
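
For context, a minimal sketch (not part of the diff) of how the new flag flows from the command line into the context parameters after these changes. The gpt_params, flash_attn, and llama_context_params_from_gpt_params names come from the hunks above; gpt_params_parse is assumed to be the common.h entry point that ends up calling gpt_params_find_arg and is not shown in this diff.

// sketch.cpp -- hypothetical usage under the assumptions stated above
#include <cstdio>
#include "common.h"
#include "llama.h"

int main(int argc, char ** argv) {
    gpt_params params;
    // Assumed helper: parses argv, so "-fa" / "--flash-attn" would set params.flash_attn (first hunk).
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }
    // The second hunk forwards the flag into the context parameters.
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    printf("flash_attn: %s\n", cparams.flash_attn ? "true" : "false");
    return 0;
}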