Commit 2d16e41

saood06 and sszymczy authored and committed
Deepseek MLA Optimizations
Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 7f61b30 commit 2d16e41

File tree

3 files changed, +190 -29 lines changed

convert_hf_to_gguf.py

Lines changed: 42 additions & 0 deletions
@@ -3123,6 +3123,7 @@ def prepare_tensors(self):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -3144,6 +3145,15 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3156,6 +3166,17 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model.layers.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 return tensors
             else:
                 return []
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
 
         return [(self.map_tensor_name(name), data_torch)]
 

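The kv_b_proj split above is what feeds the new attn_k_b and attn_v_b tensors. A quick way to check the resulting shapes is to run the same reshaping on a dummy tensor; the sketch below uses DeepSeek-V2-Lite-like dimensions, which are assumptions rather than values taken from this diff.

# Sanity check of the kv_b_proj -> (k_b, v_b) split on a dummy tensor.
# All dimensions are assumed (DeepSeek-V2-Lite-like), not read from the commit.
import torch

n_head_kv        = 16   # num_key_value_heads (assumed)
qk_nope_head_dim = 128  # qk_nope_head_dim (assumed)
v_head_dim       = 128  # v_head_dim (assumed)
kv_lora_rank     = 512  # width of the compressed c^KV input (assumed)

# kv_b_proj.weight maps c^KV (kv_lora_rank) to per-head K_nope and V
data_torch = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = data_torch.view(n_head_kv, qk_nope_head_dim + v_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)                                     # per head: (kv_lora_rank, qk_nope_head_dim)
k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

print(k_b.shape)  # torch.Size([8192, 128]) -> attn_k_b, i.e. {n_embd_head_qk_nope, n_head * kv_lora_rank} in GGUF order
print(v_b.shape)  # torch.Size([2048, 512]) -> attn_v_b, i.e. {kv_lora_rank, n_head * n_embd_head_v} in GGUF order

Per head, attn_k_b holds the transpose of the K-half of kv_b_proj, which is what lets the runtime graph multiply it directly against q_nope instead of first decompressing K for every cached token.
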
gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
@@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B        = auto()
     ATTN_KV_A_MQA   = auto()
     ATTN_KV_B       = auto()
+    ATTN_K_B        = auto()
+    ATTN_V_B        = auto()
     ATTN_Q_A_NORM   = auto()
     ATTN_KV_A_NORM  = auto()
     FFN_SUB_NORM    = auto()
@@ -403,6 +405,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B:        "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA:   "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B:       "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B:        "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B:        "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM:   "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM:  "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:   "blk.{bid}.attn_sub_norm",
@@ -967,6 +971,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

src/llama.cpp

Lines changed: 142 additions & 29 deletions
@@ -539,6 +539,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
@@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
            { LLM_TENSOR_ATTN_Q_B,       "blk.%d.attn_q_b" },
            { LLM_TENSOR_ATTN_KV_A_MQA,  "blk.%d.attn_kv_a_mqa" },
            { LLM_TENSOR_ATTN_KV_B,      "blk.%d.attn_kv_b" },
+           { LLM_TENSOR_ATTN_K_B,       "blk.%d.attn_k_b" },
+           { LLM_TENSOR_ATTN_V_B,       "blk.%d.attn_v_b" },
            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
@@ -2541,6 +2545,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b;
     struct ggml_tensor * wkv_a_mqa;
     struct ggml_tensor * wkv_b;
+    struct ggml_tensor * wk_b;
+    struct ggml_tensor * wv_b;
     struct ggml_tensor * wq_cross;
     struct ggml_tensor * wk_cross;
     struct ggml_tensor * wv_cross;
@@ -2669,11 +2675,19 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    ggml_type type_kr = GGML_TYPE_F16;
+    ggml_type type_kv = GGML_TYPE_F16;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;
 
+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+    std::vector<struct ggml_tensor *> kvt_l;
+
     std::vector<struct ggml_context *> ctxs;
     std::vector<ggml_backend_buffer_t> bufs;
 
@@ -3132,7 +3146,7 @@ static bool llama_kv_cache_init(
     for (auto & it : buft_layer_count) {
         int n_layers = it.second;
         struct ggml_init_params params = {
-            /*.mem_size   =*/ 2u*n_layers*ggml_tensor_overhead(),
+            /*.mem_size   =*/ 5u*n_layers*ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc   =*/ true,
         };
@@ -3148,6 +3162,11 @@ static bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+    cache.kvt_l.reserve(n_layer);
+
     for (int i = 0; i < (int) n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -3159,6 +3178,21 @@
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+        ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr, "cache_kr_l%d", i);
+        ggml_format_name(kv, "cache_kv_l%d", i);
+        ggml_format_name(kvt, "cache_kvt_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
+        cache.kvt_l.push_back(kvt);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
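
Two things worth noting about the cache changes above: the mem_size bump from 2u to 5u matches the five per-layer cache tensors now created (k, v, kr, kv, kvt), and the MLA tensors store one K^R vector plus two copies of c^KV (plain and transposed) per cached token. A rough footprint calculation, assuming F16 cache types and DeepSeek-V2-Lite-like hyperparameters (assumptions, not values from this diff):

# Rough MLA cache footprint per cached token (all hyperparameters assumed).
n_layer             = 27    # assumed
n_embd_head_qk_rope = 64    # hparams.n_rot (assumed)
kv_lora_rank        = 512   # hparams.n_lora_kv (assumed)
bytes_f16           = 2

# cache_kr_l holds K^R; cache_kv_l and cache_kvt_l each hold a copy of c^KV
mla_bytes_per_token = n_layer * bytes_f16 * (n_embd_head_qk_rope + 2 * kv_lora_rank)
print(mla_bytes_per_token)    # 58752 bytes, roughly 57 KiB per token

# Caching full per-head K and V instead (the path the graph below stops using)
# takes n_head * (qk_nope + qk_rope + v_head_dim) values per token per layer:
n_head, qk_nope, v_head_dim = 16, 128, 128    # assumed
full_bytes_per_token = n_layer * bytes_f16 * n_head * (qk_nope + n_embd_head_qk_rope + v_head_dim)
print(full_bytes_per_token)   # 276480 bytes, roughly 270 KiB per token
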
@@ -7644,6 +7678,8 @@ static bool llm_load_tensors(
 
                         layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
                         layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+                        layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
+                        layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -13396,31 +13432,31 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(kv_compressed, "kv_compressed", il);
 
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
+                struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+                cb(kv_cache_view, "kv_cache_view", il);
 
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
+                // note: storing c^KV in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
 
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
+                struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+                cb(kv_cache_trans_view, "kv_cache_trans_view", il);
 
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
+                // note: storing transposed c^KV in the transposed KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
 
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                        0);
-                cb(v_states, "v_states", il);
+                struct ggml_tensor * kv_cache =
+                    ggml_view_2d(ctx0, kv_self.kv_l[il],
+                            kv_lora_rank, n_kv,
+                            ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                            0);
+                cb(kv_cache, "kv_cache", il);
+
+                struct ggml_tensor * kv_cache_trans =
+                    ggml_view_2d(ctx0, kv_self.kvt_l[il],
+                            n_kv, kv_lora_rank,
+                            ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+                            0);
+                cb(kv_cache_trans, "kv_cache_trans", il);
 
                 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
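
The hunk above keeps c^KV in two layouts: kv_l appends each token's vector as a contiguous row (later read for the attention scores), while kvt_l stores the same data transposed so that the multiplication of the attention weights by c^KV can read rows of the transpose contiguously. A toy sketch of the two writes, with assumed sizes:

# Sketch of the dual c^KV layouts (toy sizes, all assumed).
import numpy as np

kv_size, kv_lora_rank = 8, 4    # cache capacity and compression rank (assumed)
kv_l  = np.zeros((kv_size, kv_lora_rank), dtype=np.float16)   # cache_kv_l:  one row per cached token
kvt_l = np.zeros((kv_lora_rank, kv_size), dtype=np.float16)   # cache_kvt_l: one column per cached token

kv_head, n_tokens = 3, 2        # write offset and batch size (assumed)
kv_compressed = np.random.rand(n_tokens, kv_lora_rank).astype(np.float16)

kv_l[kv_head:kv_head + n_tokens, :] = kv_compressed           # contiguous copy (kv_cache_view)
kvt_l[:, kv_head:kv_head + n_tokens] = kv_compressed.T        # strided copy (kv_cache_trans_view)
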
@@ -13439,15 +13475,74 @@
                 );
                 cb(k_pe, "k_pe", il);
 
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
+                struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+                cb(kr_cache_view, "kr_cache_view", il);
 
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
+                // note: storing RoPE-ed version of K^R in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
 
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                struct ggml_tensor * kr_cache =
+                    ggml_view_2d(ctx0, kv_self.kr_l[il],
+                            n_embd_head_qk_rope, n_kv,
+                            ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                            0);
+                cb(kr_cache, "kr_cache", il);
+
+                struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+                cb(wk_b, "wk_b", il);
+
+                struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope_perm, "q_nope_perm", il);
+
+                struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
+                cb(q_nope2, "q_nope2", il);
+
+                struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                cb(q_nope2_perm, "q_nope2_perm", il);
+
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+                cb(kq_nope, "kq_nope", il);
+
+                struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
+                cb(q_pe_perm, "q_pe_perm", il);
+
+                struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+                cb(kq_pe, "kq_pe", il);
+
+                struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+                cb(kq, "kq", il);
+
+                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                cb(kq, "kq_perm", il);
+
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
+
+                struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                cb(kq_perm, "kq_soft_max_ext_perm", il);
+
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+                cb(kqv_compressed, "kqv_compressed", il);
+
+                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                cb(kqv_compressed, "kqv_compressed_perm", il);
+
+                struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+                cb(wv_b, "wv_b", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+                cb(kqv, "kqv", il);
+
+                kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+                cb(kqv, "kqv_perm", il);
+
+                cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+                cb(cur, "kqv_2d", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
             }
 
             if (il == n_layer - 1) {
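
Read as plain tensor algebra, the graph above does the following per layer: q_nope is mapped into the kv_lora_rank space with attn_k_b (the K-half of kv_b_proj "absorbed" into Q), scores are computed against the compressed c^KV cache plus the shared RoPE'd K^R cache, and the weighted sum of c^KV is expanded back to per-head values with attn_v_b before the output projection. The PyTorch sketch below mirrors that sequence; the toy shapes are assumptions, and the KQ mask and the YaRN mscale adjustment to kq_scale are omitted:

# Reference sketch of the MLA attention path (not code from this commit).
import torch

n_head, n_tok, n_kv = 16, 4, 32                   # toy sizes (assumed)
qk_nope, qk_rope, v_dim, r = 128, 64, 128, 512    # r = kv_lora_rank (assumed)

q_nope = torch.randn(n_head, n_tok, qk_nope)      # Q without the RoPE part
q_pe   = torch.randn(n_head, n_tok, qk_rope)      # RoPE'd part of Q
c_kv   = torch.randn(n_kv, r)                     # cache_kv_l: compressed c^KV per cached token
k_r    = torch.randn(n_kv, qk_rope)               # cache_kr_l: shared RoPE'd K^R per cached token
wk_b   = torch.randn(n_head, r, qk_nope)          # attn_k_b: per-head transpose of the kv_b K-half
wv_b   = torch.randn(n_head, v_dim, r)            # attn_v_b: per-head kv_b V-half

scale = (qk_nope + qk_rope) ** -0.5               # plain 1/sqrt(d); mscale adjustment omitted

q_abs   = torch.einsum('hrd,htd->htr', wk_b, q_nope)        # absorb wk_b into Q      (q_nope2)
kq_nope = torch.einsum('htr,jr->htj', q_abs, c_kv)          # scores vs. c^KV cache   (kq_nope)
kq_pe   = torch.einsum('htd,jd->htj', q_pe, k_r)            # scores vs. K^R cache    (kq_pe)
p       = torch.softmax((kq_nope + kq_pe) * scale, dim=-1)  # causal KQ mask omitted
o_comp  = torch.einsum('htj,jr->htr', p, c_kv)              # weighted sum of c^KV    (kqv_compressed)
o       = torch.einsum('hvr,htr->htv', wv_b, o_comp)        # expand with wv_b        (kqv)
out     = o.permute(1, 0, 2).reshape(n_tok, n_head * v_dim) # concat heads; wo is applied to this

Full per-head K and V are never materialized; per cached token the graph only keeps c^KV (kv_lora_rank values) and the shared K^R (n_embd_head_qk_rope values).
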
@@ -17853,6 +17948,24 @@ struct llama_context * llama_new_context_with_model(
                     ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
