Skip to content

Commit 599ce84

Browse files
committed
llama : flash_attn cparam + fix defrag
1 parent 2c41180 commit 599ce84

File tree

4 files changed

+194
-159
lines changed

4 files changed

+194
-159
lines changed

common/common.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
900900
params.cont_batching = true;
901901
return true;
902902
}
903+
if (arg == "-fa" || arg == "--flash-attn") {
904+
params.flash_attn = true;
905+
return true;
906+
}
903907
if (arg == "--color") {
904908
params.use_color = true;
905909
return true;
@@ -1836,6 +1840,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
18361840
cparams.cb_eval = params.cb_eval;
18371841
cparams.cb_eval_user_data = params.cb_eval_user_data;
18381842
cparams.offload_kqv = !params.no_kv_offload;
1843+
cparams.flash_attn = params.flash_attn;
18391844

18401845
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
18411846
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -2673,6 +2678,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
26732678
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
26742679
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
26752680
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
2681+
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
26762682
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
26772683

26782684
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct gpt_params {
148148
bool multiline_input = false; // reverse the usage of `\`
149149
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
150150
bool cont_batching = true; // insert new sequences for decoding on-the-fly
151+
bool flash_attn = false; // flash attention
151152

152153
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
153154
bool ignore_eos = false; // ignore generated EOS tokens

0 commit comments

Comments
 (0)