Commit 2834711

fix: Fix all models for kv state refactors

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 552a5ce

File tree

1 file changed: +32 -27 lines changed

src/llama-model.cpp

Lines changed: 32 additions & 27 deletions
@@ -9086,6 +9086,7 @@ struct llm_build_mamba : public llm_graph_context {
 
         ggml_tensor * state_copy = build_inp_s_copy();
 
+        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
         for (int il = 0; il < n_layer; ++il) {
             // norm
             cur = build_norm(inpL,
@@ -9094,9 +9095,9 @@ struct llm_build_mamba : public llm_graph_context {
             cb(cur, "attn_norm", il);
 
             if (use_mamba2) {
-                cur = build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
+                cur = build_mamba2_layer(this, kv_state, gf, cur, state_copy, model, ubatch, il);
             } else {
-                cur = build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
+                cur = build_mamba_layer(this, kv_state, gf, cur, state_copy, model, ubatch, il);
             }
 
             if (il == n_layer - 1) {
@@ -9136,14 +9137,14 @@ struct llm_build_mamba : public llm_graph_context {
     // static layer build function that enables other models to borrow this
     // layer logic
     static ggml_tensor * build_mamba_layer(
-            const llm_graph_context * self,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) {
-        const auto * kv_state = self->get_state_recurrent();
+            const llm_graph_context * self,
+            const llama_kv_cache_recurrent_state * kv_state,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
 
         const auto kv_head = kv_state->get_head();
         auto * ctx0 = self->ctx0;
@@ -9170,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs, false, kv_state);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true, kv_state);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9272,14 +9273,14 @@ struct llm_build_mamba : public llm_graph_context {
     // static layer build function that enables other models to borrow this
     // layer logic
     static ggml_tensor * build_mamba2_layer(
-            const llm_graph_context * self,
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_model & model,
-            const llama_ubatch & ubatch,
-            int il) {
-        const auto * kv_state = self->get_state_recurrent();
+            const llm_graph_context * self,
+            const llama_kv_cache_recurrent_state * kv_state,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
 
         const auto kv_head = kv_state->get_head();
         auto * ctx0 = self->ctx0;
@@ -9303,10 +9304,10 @@ struct llm_build_mamba : public llm_graph_context {
 
         // (ab)using the KV cache to store the states
         LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state conv\n", __func__, il);
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs, false, kv_state);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
         LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state ssm\n", __func__, il);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true, kv_state);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -12965,11 +12966,15 @@ struct llm_build_hybrid_mamba : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
+        // Get the recurrent kv state
+        const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        const auto * kv_state_recurrent = kv_state->get_state_recurrent();
+
         // Build the inputs in the recurrent cache
-        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_copy = build_inp_s_copy(kv_state_recurrent);
 
-        // Build the inputs in the attention cache
-        auto * inp_attn = build_attn_inp_kv_unified();
+        // Build the attention inputs in the hybrid attention cache
+        auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();
 
         // Positional embeddings populated if rope enabled
         ggml_tensor * inp_pos = nullptr;
@@ -12989,9 +12994,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context {
             if (hparams.recurrent_layer(il)) {
                 // ssm layer //
                 if (use_mamba2) {
-                    cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
+                    cur = llm_build_mamba::build_mamba2_layer(this, kv_state_recurrent, gf, cur, state_copy, model, ubatch, il);
                 } else {
-                    cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
+                    cur = llm_build_mamba::build_mamba_layer(this, kv_state_recurrent, gf, cur, state_copy, model, ubatch, il);
                 }
             } else {
                 // attention layer //
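
For models that borrow these layer builders, the calling convention after this commit looks roughly like the sketch below. This is an illustrative summary of the hunks above, not additional code from the commit: the helper name route_recurrent_layer is hypothetical, and the type and function declarations are assumed to come from llama.cpp's internal headers.

// Hedged sketch (not part of this commit): the recurrent state is now supplied
// by the caller instead of being resolved inside each layer builder via
// self->get_state_recurrent(). A hybrid model can therefore pass in the
// recurrent half of its hybrid cache (kv_state->get_state_recurrent() above).
// "route_recurrent_layer" is a hypothetical helper used only for illustration;
// all other names appear in the diff.
static ggml_tensor * route_recurrent_layer(
        const llm_graph_context              * self,
        const llama_kv_cache_recurrent_state * kv_state,   // caller-supplied state
        ggml_cgraph                          * gf,
        ggml_tensor                          * cur,
        ggml_tensor                          * state_copy,
        const llama_model                    & model,
        const llama_ubatch                   & ubatch,
        int                                    il,
        bool                                   use_mamba2) {
    // Forward to the borrowed mamba layer builders, threading the explicit
    // state through as the second argument, as the refactored signatures expect.
    return use_mamba2
        ? llm_build_mamba::build_mamba2_layer(self, kv_state, gf, cur, state_copy, model, ubatch, il)
        : llm_build_mamba::build_mamba_layer (self, kv_state, gf, cur, state_copy, model, ubatch, il);
}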
