44#include  " llama-batch.h" 
55#include  " llama-cparams.h" 
66
7- #include  " llama-kv-cache-unified .h" 
8- #include  " llama-kv-cache-unified- iswa.h" 
7+ #include  " llama-kv-cache.h" 
8+ #include  " llama-kv-cache-iswa.h" 
99#include  " llama-memory-hybrid.h" 
1010#include  " llama-memory-recurrent.h" 
1111
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
277277                for  (int  s = 0 ; s < ubatch->n_seq_id [i0]; ++s) {
278278                    const  llama_seq_id s0 = ubatch->seq_id [i0][0 ];
279279
280-                     //  TODO: reimplement this like in llama_kv_cache_unified 
280+                     //  TODO: reimplement this like in llama_kv_cache 
281281                    if  (s0 == s1 && (!cparams.causal_attn  || ubatch->pos [i0] <= ubatch->pos [i1])) {
282282                        if  (hparams.use_alibi ) {
283283                            f = -std::abs (ubatch->pos [i0] - ubatch->pos [i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
294294    }
295295}
296296
297- void  llm_graph_input_attn_kv_unified ::set_input (const  llama_ubatch * ubatch) {
297+ void  llm_graph_input_attn_kv ::set_input (const  llama_ubatch * ubatch) {
298298    mctx->set_input_k_idxs (self_k_idxs, ubatch);
299299    mctx->set_input_v_idxs (self_v_idxs, ubatch);
300300
301301    mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
302302}
303303
304- bool  llm_graph_input_attn_kv_unified ::can_reuse (const  llm_graph_params & params) {
305-     const  auto  * mctx = static_cast <const  llama_kv_cache_unified_context  *>(params.mctx );
304+ bool  llm_graph_input_attn_kv ::can_reuse (const  llm_graph_params & params) {
305+     const  auto  * mctx = static_cast <const  llama_kv_cache_context  *>(params.mctx );
306306
307307    this ->mctx  = mctx;
308308
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
319319    return  res;
320320}
321321
322- void  llm_graph_input_attn_kv_unified_iswa ::set_input (const  llama_ubatch * ubatch) {
322+ void  llm_graph_input_attn_kv_iswa ::set_input (const  llama_ubatch * ubatch) {
323323    mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
324324    mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
325325
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
331331    mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
332332}
333333
334- bool  llm_graph_input_attn_kv_unified_iswa ::can_reuse (const  llm_graph_params & params) {
335-     const  auto  * mctx = static_cast <const  llama_kv_cache_unified_iswa_context  *>(params.mctx );
334+ bool  llm_graph_input_attn_kv_iswa ::can_reuse (const  llm_graph_params & params) {
335+     const  auto  * mctx = static_cast <const  llama_kv_cache_iswa_context  *>(params.mctx );
336336
337337    this ->mctx  = mctx;
338338
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
11861186}
11871187
11881188ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec () const  {
1189-     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_unified_context  *>(mctx);
1189+     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_context  *>(mctx);
11901190
11911191    auto  inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
11921192
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
13991399    return  cur;
14001400}
14011401
1402- static  std::unique_ptr<llm_graph_input_attn_kv_unified>  build_attn_inp_kv_unified_impl (
1402+ static  std::unique_ptr<llm_graph_input_attn_kv>  build_attn_inp_kv_impl (
14031403           ggml_context * ctx0,
14041404     const  llama_ubatch & ubatch,
14051405    const  llama_hparams & hparams,
14061406    const  llama_cparams & cparams,
1407-     const  llama_kv_cache_unified_context  * mctx_cur) {
1407+     const  llama_kv_cache_context  * mctx_cur) {
14081408
1409-     auto  inp = std::make_unique<llm_graph_input_attn_kv_unified >(hparams, cparams, mctx_cur);
1409+     auto  inp = std::make_unique<llm_graph_input_attn_kv >(hparams, cparams, mctx_cur);
14101410
14111411    {
1412-         GGML_ASSERT (hparams.swa_type  == LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache_unified_iswa  for SWA"  );
1412+         GGML_ASSERT (hparams.swa_type  == LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache_iswa  for SWA"  );
14131413
14141414        const  auto  n_kv     = mctx_cur->get_n_kv ();
14151415        const  auto  n_tokens = ubatch.n_tokens ;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
14271427    return  inp;
14281428}
14291429
1430- llm_graph_input_attn_kv_unified  * llm_graph_context::build_attn_inp_kv_unified  () const  {
1431-     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_unified_context  *>(mctx);
1430+ llm_graph_input_attn_kv  * llm_graph_context::build_attn_inp_kv  () const  {
1431+     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_context  *>(mctx);
14321432
1433-     auto  inp = build_attn_inp_kv_unified_impl (ctx0, ubatch, hparams, cparams, mctx_cur);
1433+     auto  inp = build_attn_inp_kv_impl (ctx0, ubatch, hparams, cparams, mctx_cur);
14341434
1435-     return  (llm_graph_input_attn_kv_unified  *) res->add_input (std::move (inp));
1435+     return  (llm_graph_input_attn_kv  *) res->add_input (std::move (inp));
14361436}
14371437
14381438ggml_tensor * llm_graph_context::build_attn (
1439-         llm_graph_input_attn_kv_unified  * inp,
1439+         llm_graph_input_attn_kv  * inp,
14401440        ggml_tensor * wo,
14411441        ggml_tensor * wo_b,
14421442        ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
14881488}
14891489
14901490ggml_tensor * llm_graph_context::build_attn (
1491-         llm_graph_input_attn_kv_unified_iswa  * inp,
1491+         llm_graph_input_attn_kv_iswa  * inp,
14921492        ggml_tensor * wo,
14931493        ggml_tensor * wo_b,
14941494        ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
15131513}
15141514
15151515ggml_tensor * llm_graph_context::build_attn_with_sinks (
1516-         llm_graph_input_attn_kv_unified_iswa  * inp,
1516+         llm_graph_input_attn_kv_iswa  * inp,
15171517        ggml_tensor * wo,
15181518        ggml_tensor * wo_b,
15191519        ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
16361636//  TODO: maybe separate the inner implementation into a separate function
16371637//        like with the non-sliding window equivalent
16381638//        once sliding-window hybrid caches are a thing.
1639- llm_graph_input_attn_kv_unified_iswa  * llm_graph_context::build_attn_inp_kv_unified_iswa  () const  {
1640-     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_unified_iswa_context  *>(mctx);
1639+ llm_graph_input_attn_kv_iswa  * llm_graph_context::build_attn_inp_kv_iswa  () const  {
1640+     const  auto  * mctx_cur = static_cast <const  llama_kv_cache_iswa_context  *>(mctx);
16411641
1642-     auto  inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa >(hparams, cparams, mctx_cur);
1642+     auto  inp = std::make_unique<llm_graph_input_attn_kv_iswa >(hparams, cparams, mctx_cur);
16431643
16441644    const  auto  n_stream = cparams.kv_unified  ? 1  : ubatch.n_seqs_unq ;
16451645
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
16561656    }
16571657
16581658    {
1659-         GGML_ASSERT (hparams.swa_type  != LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache_unified  for non-SWA"  );
1659+         GGML_ASSERT (hparams.swa_type  != LLAMA_SWA_TYPE_NONE && " Use llama_kv_cache  for non-SWA"  );
16601660
16611661        const  auto  n_kv = mctx_cur->get_swa ()->get_n_kv ();
16621662
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
16691669        inp->self_kq_mask_swa_cnv  = cparams.flash_attn  ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
16701670    }
16711671
1672-     return  (llm_graph_input_attn_kv_unified_iswa  *) res->add_input (std::move (inp));
1672+     return  (llm_graph_input_attn_kv_iswa  *) res->add_input (std::move (inp));
16731673}
16741674
16751675ggml_tensor * llm_graph_context::build_rs (
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
17921792    const  auto  * mctx_cur = static_cast <const  llama_memory_hybrid_context *>(mctx);
17931793
17941794    auto  inp_rs   = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
1795-     auto  inp_attn = build_attn_inp_kv_unified_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
1795+     auto  inp_attn = build_attn_inp_kv_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
17961796
17971797    auto  inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move (inp_attn), std::move (inp_rs), mctx_cur);
17981798
0 commit comments