@@ -449,6 +449,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     cvec    (params.cvec),
     loras   (params.loras),
     memory  (params.memory),
+    mstate  (params.mstate),
     cross   (params.cross),
     cb_func (params.cb),
     res     (std::make_unique<llm_graph_result>()) {
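
The constructor change is the only plumbing visible on the graph-context side: the per-ubatch memory state is captured next to the long-lived memory object. A minimal, self-contained sketch of that capture pattern, using stand-in types rather than the real llama.cpp declarations (the corresponding header change is not part of this hunk):

// stand-ins for the real interfaces; the actual base types of memory/mstate are assumed
struct memory_i       { virtual ~memory_i()       = default; };
struct memory_state_i { virtual ~memory_state_i() = default; };

struct graph_params {
    const memory_i       * memory = nullptr; // long-lived cache object (unchanged)
    const memory_state_i * mstate = nullptr; // per-ubatch state prepared by the caller (new)
};

struct graph_context {
    const memory_i       * memory;
    const memory_state_i * mstate;

    graph_context(const graph_params & params) :
        memory(params.memory),
        mstate(params.mstate) {}
};
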
@@ -1027,9 +1028,13 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->get_n_kv();
+    const auto n_kv = kv_self->get_n_kv(cstate);
 
     auto & cur = inp->pos_bucket;
 
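
The prologue added here is the pattern used throughout the rest of the commit: downcast the generic memory state to the cache-specific state interface, then null-guard before fetching the compute_state, so the accessors can presumably fall back to their previous behavior when no state is available (that fallback is assumed, not shown in this hunk). A self-contained illustration of the guard, with stand-in types in place of llama_kv_cache_unified_state_i and llama_kv_cache_unified::compute_state:

// stand-in types for illustration only
struct compute_state { int n_kv; };

struct memory_state_i { virtual ~memory_state_i() = default; };

struct kv_state_i : memory_state_i {
    const compute_state * get_cstate() const { return &cs; }
    compute_state cs{32};
};

// mirrors the call site above: guard against a missing state before dereferencing
int resolve_n_kv(const memory_state_i * mstate, int n_kv_full) {
    const auto * kv_state = static_cast<const kv_state_i *>(mstate);

    const compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;

    return cstate ? cstate->n_kv : n_kv_full; // assumed fallback when no state was provided
}
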
@@ -1233,12 +1238,16 @@ ggml_tensor * llm_graph_context::build_attn(
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n_kv();
+        const auto n_kv = kv_self->get_n_kv(cstate);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
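
The same cast-and-guard prologue now appears in build_inp_pos_bucket_dec(), in this builder, and again in the build_attn() overload below; a hypothetical private helper (not part of this commit) could centralize it. The sketch uses only names visible in these hunks and assumes mstate is the member captured in the constructor change above:

// hypothetical helper, sketched against the names visible in this commit
const llama_kv_cache_unified::compute_state * llm_graph_context::get_cstate_unified() const {
    const auto * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);

    return kv_state ? kv_state->get_cstate() : nullptr;
}
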
@@ -1270,17 +1279,21 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
+    const llama_kv_cache_unified_state_i * kv_state = static_cast<const llama_kv_cache_unified_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate = kv_state ? kv_state->get_cstate() : nullptr;
+
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_k(cstate, ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(cstate, ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv_self->get_k(cstate, ctx0, il);
+    ggml_tensor * v = kv_self->get_v(cstate, ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
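
From the updated call sites, the unified-cache accessors now take the per-ubatch compute_state as their leading argument. A declaration sketch inferred from this hunk alone (parameter types, constness, and the int32_t layer index are assumptions, not copied from the header):

// inferred signatures (sketch only, may not match llama-kv-cache.h exactly):
// ggml_tensor * llama_kv_cache_unified::get_k(const compute_state * cstate, ggml_context * ctx, int32_t il) const;
// ggml_tensor * llama_kv_cache_unified::get_v(const compute_state * cstate, ggml_context * ctx, int32_t il) const;
// ggml_tensor * llama_kv_cache_unified::cpy_k(const compute_state * cstate, ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
// ggml_tensor * llama_kv_cache_unified::cpy_v(const compute_state * cstate, ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
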
@@ -1303,10 +1316,15 @@ ggml_tensor * llm_graph_context::build_attn(
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
+    const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast<const llama_kv_cache_unified_iswa_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base() : nullptr;
+    const llama_kv_cache_unified::compute_state * cstate_swa  = kv_state ? kv_state->get_cstate_swa()  : nullptr;
+
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
     {
-        const auto n_kv = kv_self->get_kv_base()->get_n_kv();
+        const auto n_kv = kv_self->get_kv_base()->get_n_kv(cstate_base);
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
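
For the iSWA cache, the state interface evidently exposes one compute_state per underlying cache, matching the existing get_kv_base()/get_kv_swa() split. A declaration sketch inferred from these call sites (whether it is a pure-virtual interface, and what else it contains, is an assumption):

// inferred shape of the iSWA state interface (sketch only):
// struct llama_kv_cache_unified_iswa_state_i {
//     virtual const llama_kv_cache_unified::compute_state * get_cstate_base() const = 0;
//     virtual const llama_kv_cache_unified::compute_state * get_cstate_swa () const = 0;
//     // ...
// };
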
@@ -1318,7 +1336,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_self->get_kv_swa()->get_n_kv();
+        const auto n_kv = kv_self->get_kv_swa()->get_n_kv(cstate_swa);
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1354,17 +1372,24 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
 
+    const llama_kv_cache_unified_iswa_state_i * kv_state = static_cast<const llama_kv_cache_unified_iswa_state_i *>(mstate);
+
+    const llama_kv_cache_unified::compute_state * cstate_base = kv_state ? kv_state->get_cstate_base() : nullptr;
+    const llama_kv_cache_unified::compute_state * cstate_swa  = kv_state ? kv_state->get_cstate_swa()  : nullptr;
+
+    const llama_kv_cache_unified::compute_state * cstate = is_swa ? cstate_swa : cstate_base;
+
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_k(cstate, ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(cstate, ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv->get_k(ctx0, il);
-    ggml_tensor * v = kv->get_v(ctx0, il);
+    ggml_tensor * k = kv->get_k(cstate, ctx0, il);
+    ggml_tensor * v = kv->get_v(cstate, ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
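
In this overload the kv object and its compute_state are selected independently by the same is_swa flag; a hypothetical consolidation (not in the commit) could return them together so the two selections cannot drift apart. Names below are taken from the hunks; the struct and function are illustrative only:

// hypothetical consolidation: pair the cache with its per-ubatch state
struct kv_view {
    const llama_kv_cache_unified                * kv;
    const llama_kv_cache_unified::compute_state * cstate;
};

static kv_view select_kv(const llama_kv_cache_unified_iswa         * kv_self,
                         const llama_kv_cache_unified_iswa_state_i * kv_state,
                         bool is_swa) {
    return {
        is_swa   ? kv_self->get_kv_swa() : kv_self->get_kv_base(),
        kv_state ? (is_swa ? kv_state->get_cstate_swa() : kv_state->get_cstate_base()) : nullptr,
    };
}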