
Commit 85c7986

Fix cache
1 parent 52b2da6 commit 85c7986

File tree

1 file changed (+16, -20 lines)

src/llama-model.cpp

Lines changed: 16 additions & 20 deletions
@@ -15532,20 +15532,17 @@ struct llm_build_lfm2 : public llm_graph_context {
 
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto *inp = build_inp_mem_hybrid();
+        auto *inp_hybrid = build_inp_mem_hybrid();
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-        // add s_copy to graph
-        ggml_build_forward_expand(gf, inp->s_copy);
-
         for (int il = 0; il < n_layer; ++il) {
             auto *prev_cur = cur;
             cur = lfm2_rms_norm(cur, model.layers[il].attn_norm);
             cb(cur, "model.layers.{}.operator_norm", il);
 
             cur = hparams.is_recurrent(il) ?
-                build_shortconv_block(gf, cur, il) :
-                build_attn_block(gf, cur, inp_pos, inp, il) ;
+                build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
+                build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;
 
             if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -15589,12 +15586,11 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor *build_attn_block(
-            ggml_cgraph *gf,
-            ggml_tensor *cur,
-            ggml_tensor *inp_pos,
-            llm_graph_input_mem_hybrid *inp,
-            int il) const {
+    ggml_tensor *build_attn_block(ggml_cgraph *gf,
+                                  ggml_tensor *cur,
+                                  ggml_tensor *inp_pos,
+                                  llm_graph_input_attn_kv_unified *inp_attn,
+                                  int il) const {
         GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
         auto const n_embd_head = hparams.n_embd_head_v;
         auto const n_head_kv = hparams.n_head_kv(il);
@@ -15628,19 +15624,18 @@ struct llm_build_lfm2 : public llm_graph_context {
             ext_factor, attn_factor, beta_fast, beta_slow
         );
 
-        cur = build_attn(inp, gf,
-                model.layers[il].wo, NULL,
+        cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
                 q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
         cb(cur, "model.layers.{}.self_attn.out_proj", il);
 
         return cur;
     }
 
-    ggml_tensor * build_shortconv_block(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            int il) {
+    ggml_tensor * build_shortconv_block(ggml_cgraph * gf,
+                                        ggml_tensor * cur,
+                                        llm_graph_input_rs *inp_recr,
+                                        int il) {
         const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
         auto *bcx = ggml_mul_mat(ctx0, model.layers[il].shortconv.in_proj, cur);
@@ -15656,9 +15651,10 @@ struct llm_build_lfm2 : public llm_graph_context {
         auto *bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
         // read conv state directly, with build_rs generation is slower
-        const int64_t n_seqs = ubatch.n_seqs;
         ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        auto *conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        const int64_t n_seqs = ubatch.n_seqs;
+        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
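Note (added for illustration, not part of the commit): the diff above stops passing the whole llm_graph_input_mem_hybrid into the per-layer builders and instead hands each builder the sub-view it actually uses, inp_hybrid->get_attn() for attention layers and inp_hybrid->get_recr() for the shortconv layers, with the recurrent state now pulled through build_rs. The stand-alone sketch below mirrors only that dispatch shape; every type and function name in it is a made-up stand-in, not the real llama.cpp API.

// Hedged sketch of the "hybrid input split into per-layer views" pattern.
// attn_input / recr_input stand in for llm_graph_input_attn_kv_unified /
// llm_graph_input_rs; hybrid_input stands in for llm_graph_input_mem_hybrid.
#include <cstdio>

struct attn_input { int kv_slot; };     // stand-in: view over the unified KV cache
struct recr_input { int state_slot; };  // stand-in: view over the recurrent state

// The hybrid input owns both views; each layer asks only for the one it needs.
struct hybrid_input {
    attn_input attn;
    recr_input recr;
    attn_input * get_attn() { return &attn; }
    recr_input * get_recr() { return &recr; }
};

static void build_attn_block(attn_input * inp_attn, int il) {
    std::printf("layer %d: attention reads KV view %d\n", il, inp_attn->kv_slot);
}

static void build_shortconv_block(recr_input * inp_recr, int il) {
    std::printf("layer %d: shortconv reads recurrent-state view %d\n", il, inp_recr->state_slot);
}

int main() {
    hybrid_input inp_hybrid{ {0}, {1} };
    const bool is_recurrent[4] = { true, true, false, true };  // per-layer layout, invented for the demo
    for (int il = 0; il < 4; ++il) {
        is_recurrent[il] ? build_shortconv_block(inp_hybrid.get_recr(), il)
                         : build_attn_block(inp_hybrid.get_attn(), il);
    }
    return 0;
}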
