Skip to content

Commit 7654331

Browse files
committed
llama : avoid ggml_cont() if possible in DeepSeek V2 implementation
1 parent 8a887de commit 7654331

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

src/llama.cpp

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -6533,10 +6533,10 @@ struct llm_build_context {
65336533
struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
65346534
cb(wk_b, "wk_b", il);
65356535

6536-
struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
6537-
cb(q_nope_perm, "q_nope_perm", il);
6536+
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
6537+
cb(q_nope, "q_nope_perm", il);
65386538

6539-
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
6539+
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
65406540
cb(q_nope2, "q_nope2", il);
65416541

65426542
if (!pp_opt) {
@@ -6547,6 +6547,11 @@ struct llm_build_context {
65476547
struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
65486548
cb(kq_nope, "kq_nope", il);
65496549

6550+
if (!pp_opt) {
6551+
kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
6552+
cb(kq_nope, "kq_nope_perm", il);
6553+
}
6554+
65506555
if (pp_opt) {
65516556
q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
65526557
cb(q_pe, "q_pe_perm", il);
@@ -6555,14 +6560,14 @@ struct llm_build_context {
65556560
struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
65566561
cb(kq_pe, "kq_pe", il);
65576562

6558-
struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
6559-
cb(kq, "kq", il);
6560-
65616563
if (!pp_opt) {
6562-
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
6563-
cb(kq, "kq_perm", il);
6564+
kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
6565+
cb(kq_pe, "kq_pe_perm", il);
65646566
}
65656567

6568+
struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
6569+
cb(kq, "kq", il);
6570+
65666571
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
65676572
cb(kq, "kq_soft_max_ext", il);
65686573

@@ -6575,7 +6580,7 @@ struct llm_build_context {
65756580
cb(kqv_compressed, "kqv_compressed", il);
65766581

65776582
if (!pp_opt) {
6578-
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
6583+
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 3, 1);
65796584
cb(kqv_compressed, "kqv_compressed_perm", il);
65806585
}
65816586

@@ -6585,8 +6590,10 @@ struct llm_build_context {
65856590
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
65866591
cb(kqv, "kqv", il);
65876592

6588-
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
6589-
cb(kqv, "kqv_perm", il);
6593+
if (pp_opt) {
6594+
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
6595+
cb(kqv, "kqv_perm", il);
6596+
}
65906597

65916598
cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
65926599
cb(cur, "kqv_2d", il);

0 commit comments

Comments
 (0)