
Commit c12f73b

Authored by Iwan Kawrakow (ikawrakow), with co-authors saood06 and sszymczy
Add optional MLA (#188)
* Deepseek MLA Optimizations

Co-authored-by: Stanisław Szymczyk <[email protected]>

* Make MLA optional

* Remove some unnecessary copies in the MLA attention

* Deepseek MLA Optimizations V2 (#195)

* Avoid allocating MHA KV cache when MLA is turned on

* Added missing gguf-py file

* Added final optimizations

Co-authored-by: Stanisław Szymczyk <[email protected]>

* Make sure we do have wk_b and wv_b before enabling MLA

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Iwan Kawrakow <[email protected]>

* Use type_k and type_v to set the types of the MLA caches

They were hard-coded at f16. On my Ryzen-7950X with native bf16 support I get a
fairly significant PP performance boost with a bf16 KV-cache: PP-4096 = 320 t/s,
up from 292 t/s with an fp16 KV-cache.

* Better gemm strategy when nth > nhead

It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads (with or
without MLA). Before this commit, when nth > nhead the heads were processed
sequentially, with all nth threads participating in each matrix multiplication.
Now we find the gcd of nhead and nth and split the threads into groups of
nth/gcd threads, each group processing nhead/gcd heads.

---------

Co-authored-by: Saood Karim <[email protected]>
Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent cae2b81 commit c12f73b

File tree

9 files changed: 380 additions and 75 deletions


common/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.flash_attn = true;
         return true;
     }
+    if (arg == "-mla" || arg == "--mla-use") {
+        params.mla_attn = true;
+        return true;
+    }
     if (arg == "-co" || arg == "--color") {
         params.use_color = true;
         return true;
@@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
     options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
     options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                                                     "in conversation mode, this will be used as system prompt\n"
                                                     "(default: '%s')", params.prompt.c_str() });
@@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
+    cparams.mla_attn = params.mla_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
+    fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
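
Taken together with the include/llama.h change below, this wires a new `-mla` / `--mla-use` switch, analogous to `-fa`, into every tool that parses its options through common.cpp and forwards it to the context as cparams.mla_attn. As a purely illustrative example (binary and model file names are placeholders, not part of this commit), an invocation such as `llama-cli -m deepseek-lite.gguf -fa -mla -p "..."` would request the MLA attention path; per the commit message, MLA is only actually enabled when the model provides the wk_b/wv_b (attn_k_b/attn_v_b) tensors, i.e. a GGUF produced with the updated convert_hf_to_gguf.py.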

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ struct gpt_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
+    bool mla_attn = false; // MLA
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos = false; // ignore generated EOS tokens

convert_hf_to_gguf.py

Lines changed: 42 additions & 0 deletions
@@ -3123,6 +3123,7 @@ def prepare_tensors(self):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -3144,6 +3145,15 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3156,6 +3166,17 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model.layers.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 return tensors
             else:
                 return []
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
 
         return [(self.map_tensor_name(name), data_torch)]
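
For reference, the fused kv_b_proj weight has shape (n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank), and the code above carves it into a transposed per-head k_b of shape (n_head_kv * kv_lora_rank, qk_nope_head_dim) plus a v_b of shape (n_head_kv * v_head_dim, kv_lora_rank). Below is a minimal standalone sketch of the same splitting, using made-up toy dimensions rather than real DeepSeek sizes:

    import torch

    # hypothetical toy dimensions, chosen only to make the shapes easy to follow
    n_head_kv        = 4
    qk_nope_head_dim = 6
    v_head_dim       = 5
    kv_lora_rank     = 8

    # fused kv_b_proj weight: per head, the first qk_nope_head_dim rows are the
    # K up-projection and the remaining v_head_dim rows are the V up-projection
    kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

    # k_b is stored transposed: n_head_kv blocks of shape (kv_lora_rank, qk_nope_head_dim)
    k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
    v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

    print(kv_b_proj.shape)  # torch.Size([44, 8])
    print(k_b.shape)        # torch.Size([32, 6])
    print(v_b.shape)        # torch.Size([20, 8])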

examples/llama-bench/llama-bench.cpp

Lines changed: 32 additions & 3 deletions
@@ -232,6 +232,7 @@ struct cmd_params {
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
     std::vector<bool> flash_attn;
+    std::vector<bool> mla_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
     /* main_gpu */ {0},
     /* no_kv_offload */ {false},
     /* flash_attn */ {false},
+    /* mla_attn */ {false},
     /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap */ {true},
     /* embeddings */ {false},
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
+    printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
     printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = string_split<bool>(argv[i], split_delim);
             params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+        } else if (arg == "-mla" || arg == "--mla-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = string_split<bool>(argv[i], split_delim);
+            params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
     if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
+    if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
+    bool mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
         cparams.flash_attn = flash_attn;
+        cparams.mla_attn = mla_attn;
         cparams.embeddings = embeddings;
 
         return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
+    for (const auto & mla : params.mla_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn = */ fa,
+                /* .mla_attn = */ mla,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
@@ -866,6 +884,7 @@ struct test {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
+    bool mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -895,6 +914,7 @@ struct test {
         main_gpu = inst.main_gpu;
         no_kv_offload = inst.no_kv_offload;
         flash_attn = inst.flash_attn;
+        mla_attn = inst.mla_attn;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
@@ -988,7 +1008,7 @@ struct test {
         "n_batch", "n_ubatch",
         "n_threads", "type_k", "type_v",
         "n_gpu_layers", "split_mode",
-        "main_gpu", "no_kv_offload", "flash_attn",
+        "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
         "tensor_split", "use_mmap", "embeddings", "repack",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
@@ -1010,7 +1030,7 @@ struct test {
         }
         if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-            field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
+            field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -1044,7 +1064,7 @@ struct test {
            std::to_string(n_batch), std::to_string(n_ubatch),
            std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
            std::to_string(n_gpu_layers), split_mode_str(split_mode),
-           std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+           std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
            std::to_string(n_prompt), std::to_string(n_gen), test_time,
            std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
         if (field == "flash_attn") {
             return 2;
         }
+        if (field == "mla_attn") {
+            return 3;
+        }
         if (field == "use_mmap") {
             return 4;
         }
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
         if (field == "flash_attn") {
             return "fa";
         }
+        if (field == "mla_attn") {
+            return "mla";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
         if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
             fields.emplace_back("flash_attn");
         }
+        if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
+            fields.emplace_back("mla_attn");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
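
Since llama-bench takes comma-separated value lists and benchmarks the full cross product, the new option can be swept together with the existing ones, e.g. `-mla 0,1 -fa 0,1` covers all four combinations; whenever the tested values differ from the default, the markdown output gains an extra `mla` column next to `fa`, as set up in the markdown_printer changes above.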

ggml/src/ggml.c

Lines changed: 4 additions & 13 deletions
@@ -14064,31 +14064,22 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_IQK_MULMAT
-    if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+    if (dst->type == GGML_TYPE_F32) {
+        int gcd = simple_gcd(ne12*ne13, nth);
         int counter = 0;
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
-                if (counter++ % nth == ith) {
+                if ((counter++ % gcd) == (ith%gcd)) {
                     if (!iqk_mul_mat(ne01, ne11, ne00,
                         src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
                         src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
                         (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                        0, 1)) goto IQK_MulMat_Not_Available1;
+                        ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1;
                 }
             }
         }
         return;
     }
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!iqk_mul_mat(ne01, ne11, ne00,
-                    src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
-                    src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
-                    (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                    ith, nth)) goto IQK_MulMat_Not_Available1;
-        return;
-    }
 IQK_MulMat_Not_Available1:;
 #endif
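
For intuition, the new scheduling can be mimicked in a few lines of Python (a hypothetical sketch, not part of the commit): with nmat = ne12*ne13 matrix slices (the attention heads in the MLA case) and nth threads, the threads are split into gcd(nmat, nth) groups of nth/gcd threads, and each group covers nmat/gcd slices. When nmat is a multiple of nth this reduces to the old one-thread-per-slice path; when the gcd is 1 it reduces to the old everyone-works-on-every-slice path.

    import math

    def mul_mat_schedule(nmat: int, nth: int):
        """Mimic the gcd-based work split above: thread `ith` handles slice `m`
        iff m % gcd == ith % gcd, acting as sub-thread ith//gcd of nth//gcd."""
        gcd = math.gcd(nmat, nth)
        plan = {}
        for ith in range(nth):
            slices = [m for m in range(nmat) if m % gcd == ith % gcd]
            plan[ith] = (slices, ith // gcd, nth // gcd)
        return plan

    # Example: 4 slices, 6 threads -> gcd = 2, i.e. 2 groups of 3 threads,
    # each group covering 2 slices with 3-way parallelism per multiplication.
    for ith, (slices, sub_ith, sub_nth) in mul_mat_schedule(4, 6).items():
        print(f"thread {ith}: slices {slices}, sub-thread {sub_ith} of {sub_nth}")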

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
@@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B = auto()
     ATTN_KV_A_MQA = auto()
     ATTN_KV_B = auto()
+    ATTN_K_B = auto()
+    ATTN_V_B = auto()
     ATTN_Q_A_NORM = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM = auto()
@@ -403,6 +405,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -967,6 +971,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -446,6 +446,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ extern "C" {
         bool embeddings; // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+        bool mla_attn; // whether to use MLA attention [EXPERIMENTAL]
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
