@@ -9086,6 +9086,7 @@ struct llm_build_mamba : public llm_graph_context {
 
         ggml_tensor * state_copy = build_inp_s_copy();
 
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
         for (int il = 0; il < n_layer; ++il) {
             // norm
             cur = build_norm(inpL,
@@ -9094,9 +9095,9 @@ struct llm_build_mamba : public llm_graph_context {
             cb(cur, "attn_norm", il);
 
             if (use_mamba2) {
-                cur = build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
+                cur = build_mamba2_layer(this, kv_state, gf, cur, state_copy, model, ubatch, il);
             } else {
-                cur = build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
+                cur = build_mamba_layer(this, kv_state, gf, cur, state_copy, model, ubatch, il);
             }
 
             if (il == n_layer - 1) {
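Taken together, the two hunks above replace the implicit `self->get_state_recurrent()` lookup with an explicit `kv_state` argument, so the same static layer builders work whether the memory state is a plain recurrent cache or the recurrent half of a hybrid cache. A minimal sketch of the two resulting call patterns, using only names that appear in this diff (the local variable `kv_state_hybrid` is renamed here purely to keep the two cases distinct):

```cpp
// Sketch only: the two ways the shared Mamba layer builders are now invoked.

// Pure Mamba model: mstate itself is the recurrent cache state.
const auto * kv_state =
    static_cast<const llama_kv_cache_recurrent_state *>(mstate);
cur = build_mamba_layer(this, kv_state, gf, cur, state_copy, model, ubatch, il);

// Hybrid Mamba/attention model: extract the recurrent sub-state first
// (see the llm_build_hybrid_mamba hunks further down).
const auto * kv_state_hybrid =
    static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
cur = llm_build_mamba::build_mamba_layer(
        this, kv_state_hybrid->get_state_recurrent(),
        gf, cur, state_copy, model, ubatch, il);
```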
@@ -9136,14 +9137,14 @@ struct llm_build_mamba : public llm_graph_context {
     // static layer build function that enables other models to borrow this
     // layer logic
     static ggml_tensor * build_mamba_layer(
-            const llm_graph_context * self,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) {
-        const auto * kv_state = self->get_state_recurrent();
+            const llm_graph_context * self,
+            const llama_kv_cache_recurrent_state * kv_state,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
 
         const auto kv_head = kv_state->get_head();
         auto * ctx0 = self->ctx0;
@@ -9170,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs, false, kv_state);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true, kv_state);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
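The two call sites above now pass a trailing `false, kv_state` / `true, kv_state` pair instead of letting `build_recurrent_state` reach back into `self` for the state. The helper's declaration is not part of this diff; inferred from these callers it plausibly looks like the sketch below, where the parameter names (in particular the boolean, which reads like an avoid-copies flag) are assumptions:

```cpp
// Hypothetical declaration inferred from the call sites above; the real
// signature lives in the graph-context header and may differ.
ggml_tensor * build_recurrent_state(
        ggml_cgraph * gf,
        ggml_tensor * s,            // full per-layer state tensor (conv or ssm states)
        ggml_tensor * state_copy,   // copy indices from build_inp_s_copy()
        int32_t       state_size,   // hparams.n_embd_k_s(il) / n_embd_v_s(il)
        int32_t       n_seqs,
        bool          avoid_copies, // assumed meaning of the false/true argument
        const llama_kv_cache_recurrent_state * kv_state) const;
```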
@@ -9272,14 +9273,14 @@ struct llm_build_mamba : public llm_graph_context {
     // static layer build function that enables other models to borrow this
     // layer logic
     static ggml_tensor * build_mamba2_layer(
-            const llm_graph_context * self,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) {
-        const auto * kv_state = self->get_state_recurrent();
+            const llm_graph_context * self,
+            const llama_kv_cache_recurrent_state * kv_state,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
 
         const auto kv_head = kv_state->get_head();
         auto * ctx0 = self->ctx0;
@@ -9303,10 +9304,10 @@ struct llm_build_mamba : public llm_graph_context {
 
         // (ab)using the KV cache to store the states
         LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state conv\n", __func__, il);
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs, false, kv_state);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
         LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state ssm\n", __func__, il);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true, kv_state);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -12965,11 +12966,15 @@ struct llm_build_hybrid_mamba : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
+        // Get the recurrent kv state
+        const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        const auto * kv_state_recurrent = kv_state->get_state_recurrent();
+
         // Build the inputs in the recurrent cache
-        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_copy = build_inp_s_copy(kv_state_recurrent);
 
-        // Build the inputs in the attention cache
-        auto * inp_attn = build_attn_inp_kv_unified();
+        // Build the attention inputs in the hybrid attention cache
+        auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();
 
         // Positional embeddings populated if rope enabled
         ggml_tensor * inp_pos = nullptr;
@@ -12989,9 +12994,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context {
             if (hparams.recurrent_layer(il)) {
                 // ssm layer //
                 if (use_mamba2) {
-                    cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
+                    cur = llm_build_mamba::build_mamba2_layer(this, kv_state_recurrent, gf, cur, state_copy, model, ubatch, il);
                 } else {
-                    cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
+                    cur = llm_build_mamba::build_mamba_layer(this, kv_state_recurrent, gf, cur, state_copy, model, ubatch, il);
                 }
             } else {
                 // attention layer //
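Two input-builder calls change alongside the layer dispatch: the attention inputs now come from `build_attn_inp_kv_hybrid_recurrent()` instead of `build_attn_inp_kv_unified()`, and `build_inp_s_copy()` is handed the recurrent sub-state explicitly. Since the pure-Mamba constructor earlier in this diff still calls `build_inp_s_copy()` with no argument, the parameter is presumably optional; a hedged sketch of the implied declaration:

```cpp
// Hypothetical declaration implied by the two call sites in this diff
// (no-argument call in llm_build_mamba, explicit state in the hybrid build);
// the default value and the exact fallback behaviour are assumptions.
ggml_tensor * build_inp_s_copy(
        const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
```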