Commit 37c4fbd: "Make MLA optional"
Author: Iwan Kawrakow
Parent: 35246c4

File tree

5 files changed: +215, -98 lines

common/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.flash_attn = true;
         return true;
     }
+    if (arg == "-mla" || arg == "--mla-use") {
+        params.mla_attn = true;
+        return true;
+    }
     if (arg == "-co" || arg == "--color") {
         params.use_color = true;
         return true;
@@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
     options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
     options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
         "in conversation mode, this will be used as system prompt\n"
         "(default: '%s')", params.prompt.c_str() });
@@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
+    cparams.mla_attn = params.mla_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
+    fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
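
Taken together, the four hunks above wire the new flag end to end: gpt_params_find_arg() accepts -mla/--mla-use, the usage text advertises it, llama_context_params_from_gpt_params() copies it into the context parameters, and the YAML dump records it. A minimal sketch of that flow (illustrative only, not part of the commit; it assumes only the helpers shown above):

    // Hedged sketch: how "-mla" reaches llama_context_params.
    gpt_params params;              // params.mla_attn defaults to false
    params.mla_attn = true;         // what gpt_params_find_arg() sets for "-mla" / "--mla-use"
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    // cparams.mla_attn is now true; a context created from cparams will use MLA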

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ struct gpt_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
+    bool mla_attn = false; // MLA
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos = false; // ignore generated EOS tokens

examples/llama-bench/llama-bench.cpp

Lines changed: 32 additions & 3 deletions
@@ -232,6 +232,7 @@ struct cmd_params {
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
     std::vector<bool> flash_attn;
+    std::vector<bool> mla_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
     /* main_gpu */ {0},
     /* no_kv_offload */ {false},
     /* flash_attn */ {false},
+    /* mla_attn */ {false},
     /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap */ {true},
     /* embeddings */ {false},
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
+    printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
     printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
            }
            auto p = string_split<bool>(argv[i], split_delim);
            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+        } else if (arg == "-mla" || arg == "--mla-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = string_split<bool>(argv[i], split_delim);
+            params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
     if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
+    if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
+    bool mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
         cparams.flash_attn = flash_attn;
+        cparams.mla_attn = mla_attn;
         cparams.embeddings = embeddings;
 
         return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
+    for (const auto & mla : params.mla_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -866,6 +884,7 @@ struct test {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
+    bool mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -895,6 +914,7 @@ struct test {
        main_gpu = inst.main_gpu;
        no_kv_offload = inst.no_kv_offload;
        flash_attn = inst.flash_attn;
+       mla_attn = inst.mla_attn;
        tensor_split = inst.tensor_split;
        use_mmap = inst.use_mmap;
        embeddings = inst.embeddings;
@@ -988,7 +1008,7 @@ struct test {
            "n_batch", "n_ubatch",
            "n_threads", "type_k", "type_v",
            "n_gpu_layers", "split_mode",
-           "main_gpu", "no_kv_offload", "flash_attn",
+           "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
            "tensor_split", "use_mmap", "embeddings", "repack",
            "n_prompt", "n_gen", "test_time",
            "avg_ns", "stddev_ns",
@@ -1010,7 +1030,7 @@ struct test {
        }
        if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
            field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-           field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
+           field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
            return BOOL;
        }
        if (field == "avg_ts" || field == "stddev_ts") {
@@ -1044,7 +1064,7 @@ struct test {
            std::to_string(n_batch), std::to_string(n_ubatch),
            std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
            std::to_string(n_gpu_layers), split_mode_str(split_mode),
-           std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+           std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
            std::to_string(n_prompt), std::to_string(n_gen), test_time,
            std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
        if (field == "flash_attn") {
            return 2;
        }
+       if (field == "mla_attn") {
+           return 3;
+       }
        if (field == "use_mmap") {
            return 4;
        }
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
        if (field == "flash_attn") {
            return "fa";
        }
+       if (field == "mla_attn") {
+           return "mla";
+       }
        if (field == "use_mmap") {
            return "mmap";
        }
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
            fields.emplace_back("flash_attn");
        }
+       if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
+           fields.emplace_back("mla_attn");
+       }
        if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
            fields.emplace_back("tensor_split");
        }
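
llama-bench treats every option as a vector and benchmarks the full cross-product (note the added "for (const auto & mla : params.mla_attn)" in the loop stack above), so MLA can be swept together with flash attention. A hedged usage sketch; the model path is hypothetical:

    ./llama-bench -m model.gguf -fa 0,1 -mla 0,1

This expands to four test instances (each fa value paired with each mla value), and because params.mla_attn then differs from its default, markdown_printer adds the abbreviated "mla" column to the results table.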

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ extern "C" {
         bool embeddings; // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+        bool mla_attn; // whether to use MLA attention [EXPERIMENTAL]
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
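
For programs that use the C API directly rather than the common helpers, the new field is set on llama_context_params before creating the context. A minimal sketch, assuming the standard llama.cpp C API entry points of this period (llama_load_model_from_file(), llama_new_context_with_model()) and a hypothetical model path; MLA here refers to multi-head latent attention:

    #include "llama.h"

    int main(void) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams); // hypothetical path
        llama_context_params cparams = llama_context_default_params();
        cparams.mla_attn = true; // opt in; per this commit, MLA is off unless requested
        llama_context * ctx = llama_new_context_with_model(model, cparams);
        // ... evaluate tokens with llama_decode() ...
        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }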
