Skip to content

Commit 521c1e0

Browse files
mitmul and CISC authored
Apply suggestion from @CISC
Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 5231e4f commit 521c1e0

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

src/llama-model.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15636,21 +15636,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1563615636
const int64_t k_offset = n_embd_head_q * n_head;
1563715637
const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
1563815638

15639-
ggml_tensor * Qcur = ggml_view_2d(ctx0, qkv, n_embd_head_q * n_head, n_tokens, qkv->nb[1], q_offset * ggml_element_size(qkv));
15640-
ggml_tensor * Kcur = ggml_view_2d(ctx0, qkv, n_embd_head_k * n_head_kv, n_tokens, qkv->nb[1], k_offset * ggml_element_size(qkv));
15641-
ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv));
15642-
15643-
// make tensors contiguous before reshape
15644-
Qcur = ggml_cont(ctx0, Qcur);
15645-
Kcur = ggml_cont(ctx0, Kcur);
15646-
Vcur = ggml_cont(ctx0, Vcur);
15639+
ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
15640+
ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
15641+
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)));
1564715642

1564815643
cb(Qcur, "Qcur", il);
1564915644
cb(Kcur, "Kcur", il);
1565015645
cb(Vcur, "Vcur", il);
1565115646

15652-
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_q, n_head, n_tokens);
15653-
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
1565415647
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
1565515648

1565615649
Qcur = build_norm(Qcur, model.layers[il].wq, NULL, LLM_NORM_RMS, il);

0 commit comments

Comments (0)