
Commit 8a887de

llama : prompt processing optimizations in DeepSeek V2
1 parent 8ff0991 commit 8a887de

File tree

1 file changed: +26 −12 lines


src/llama.cpp

Lines changed: 26 additions & 12 deletions
@@ -6403,6 +6403,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens as the matrix dimension during multiplication or n_head
+        // n_tokens is higher during prompt processing, this allows to optimize for this case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6535,35 +6539,45 @@ struct llm_build_context {
                 struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
                 cb(q_nope2, "q_nope2", il);
 
-                struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-                cb(q_nope2_perm, "q_nope2_perm", il);
+                if (!pp_opt) {
+                    q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                    cb(q_nope2, "q_nope2_perm", il);
+                }
 
-                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                 cb(kq_nope, "kq_nope", il);
 
-                struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
-                cb(q_pe_perm, "q_pe_perm", il);
+                if (pp_opt) {
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+                }
 
                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                 cb(kq_pe, "kq_pe", il);
 
                 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
                 cb(kq, "kq", il);
 
-                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-                cb(kq, "kq_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+                    cb(kq, "kq_perm", il);
+                }
 
                 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                 cb(kq, "kq_soft_max_ext", il);
 
-                struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-                cb(kq_perm, "kq_soft_max_ext_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                    cb(kq, "kq_soft_max_ext_perm", il);
+                }
 
-                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                 cb(kqv_compressed, "kqv_compressed", il);
 
-                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-                cb(kqv_compressed, "kqv_compressed_perm", il);
+                if (!pp_opt) {
+                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                    cb(kqv_compressed, "kqv_compressed_perm", il);
+                }
 
                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
                 cb(wv_b, "wv_b", il);
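For context, here is a minimal standalone sketch of the layout choice behind pp_opt. It is not part of the commit: it assumes only the public ggml C API (ggml_init, ggml_new_tensor_3d, ggml_mul_mat), and the tensor names (kv_cache, q_pp, q_tg) and sizes are made up for illustration. ggml_mul_mat(ctx, a, b) contracts over ne[0] and batches over b's higher dimensions, so putting the larger of n_tokens / n_head into b->ne[1] during prompt processing yields fewer, larger per-head multiplications, which is also what lets most of the ggml_permute/ggml_cont steps be skipped in the pp_opt branch above.

// sketch.c -- shape-only illustration, no data is allocated or computed
#include "ggml.h"
#include <stdio.h>

int main(void) {
    // hypothetical sizes: a long prompt (n_tokens) vs. the number of attention heads
    const int64_t kv_lora_rank = 512, n_kv = 4096, n_head = 16, n_tokens = 1024;

    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(params);

    // stand-in for the compressed KV cache: kv_lora_rank x n_kv
    struct ggml_tensor * kv_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv_lora_rank, n_kv);

    // pp_opt layout: n_tokens is the matrix dimension, n_head the batch dimension
    struct ggml_tensor * q_pp = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, kv_lora_rank, n_tokens, n_head);
    // decode layout: n_head is the matrix dimension, n_tokens the batch dimension
    struct ggml_tensor * q_tg = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, kv_lora_rank, n_head, n_tokens);

    // ggml_mul_mat(a, b) contracts over ne[0] and batches over b->ne[2..3]
    struct ggml_tensor * kq_pp = ggml_mul_mat(ctx, kv_cache, q_pp); // n_kv x n_tokens x n_head
    struct ggml_tensor * kq_tg = ggml_mul_mat(ctx, kv_cache, q_tg); // n_kv x n_head x n_tokens

    printf("pp_opt: %lld x %lld x %lld -> %lld matmuls of %lld rows each\n",
           (long long) kq_pp->ne[0], (long long) kq_pp->ne[1], (long long) kq_pp->ne[2],
           (long long) n_head, (long long) n_tokens);
    printf("decode: %lld x %lld x %lld -> %lld matmuls of %lld rows each\n",
           (long long) kq_tg->ne[0], (long long) kq_tg->ne[1], (long long) kq_tg->ne[2],
           (long long) n_tokens, (long long) n_head);

    ggml_free(ctx);
    return 0;
}

During token generation n_tokens is typically 1, so pp_opt (n_tokens > n_head) is false and the !pp_opt branches keep the original layout with n_head as the matrix dimension; the extra permutes only disappear in the prompt-processing case the commit targets.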
