
Commit 88213a9

refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remove "kv cache"

This removes the notion of "kv" from the interface names for these memory types: llama_kv_cache_recurrent becomes llama_memory_recurrent, and llama_kv_cache_hybrid_recurrent becomes llama_memory_hybrid. There are still many references to "kv" in the implementation of the recurrent memory, which will need further adjustment.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 8488f5e · commit 88213a9

12 files changed: +483 -485 lines
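For downstream code the rename is mechanical: casts against the old `llama_kv_cache_*_state` types become casts against the new `llama_memory_*_state` types. A minimal before/after sketch — the `mstate` member and `get_state_attn()` getter appear in the diff below, but the surrounding call site is hypothetical:

    // before this commit: graph-building code cast the memory state
    // to the old kv-cache state type
    const auto * kv_state =
        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

    // after this commit: the same cast, against the renamed type
    const auto * mem_state =
        static_cast<const llama_memory_hybrid_state *>(mstate);

    // getters such as get_state_attn() are unchanged by the rename
    const auto * attn_state = mem_state->get_state_attn();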

src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions

@@ -22,9 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
-            llama-kv-cache-hybrid-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion

@@ -1831,7 +1831,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
     }
 }
 
-bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
+bool llm_arch_is_hybrid(const llm_arch & arch) {
     // TODO: There are currently no hybrid models! Once there are, this will be
     // the place to identify them
     switch (arch) {
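Callers dispatch on the renamed predicate exactly as they did on `llm_arch_is_hybrid_recurrent`. A hypothetical call site, sketched for illustration (the branch bodies are not from this commit):

    // hypothetical dispatch in memory-creation code: pick the hybrid
    // memory implementation when the architecture interleaves attention
    // and recurrent layers
    if (llm_arch_is_hybrid(arch)) {
        // construct a llama_memory_hybrid here
    } else if (llm_arch_is_recurrent(arch)) {
        // construct a llama_memory_recurrent here
    }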

src/llama-arch.h

Lines changed: 1 addition & 1 deletion

@@ -442,4 +442,4 @@ llm_arch llm_arch_from_string(const std::string & name);
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
 
 bool llm_arch_is_recurrent(const llm_arch& arch);
-bool llm_arch_is_hybrid_recurrent(const llm_arch& arch);
+bool llm_arch_is_hybrid (const llm_arch& arch);

src/llama-graph.cpp

Lines changed: 9 additions & 9 deletions

@@ -6,8 +6,8 @@
 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
-#include "llama-kv-cache-hybrid-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -1050,7 +1050,7 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
 
@@ -1447,7 +1447,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
 
     // store to KV cache
     {
@@ -1553,7 +1553,7 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
@@ -1572,7 +1572,7 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
@@ -1584,7 +1584,7 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recurrent();
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
@@ -1594,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1615,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;

src/llama-graph.h

Lines changed: 8 additions & 8 deletions

@@ -21,8 +21,8 @@ struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
-class llama_kv_cache_hybrid_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -191,14 +191,14 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -306,7 +306,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+            const llama_memory_hybrid_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
         kv_state(kv_state) {
@@ -325,7 +325,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_hybrid_recurrent_state * kv_state;
+    const llama_memory_hybrid_state * kv_state;
 };
 
 //
@@ -635,11 +635,11 @@ struct llm_graph_context {
     //
 
     // TODO: avoid notion of "kv"
-    // TODO: move this implementation to llama_kv_cache_recurrent.
+    // TODO: move this implementation to llama_memory_recurrent.
     // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
-    // `llama_kv_cache_recurrent`
+    // `llama_memory_recurrent`
     ggml_tensor * build_rs(
             ggml_cgraph * gf,
             ggml_tensor * s,
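The TODO in the last hunk sketches a follow-up: move `build_rs` into `llama_memory_recurrent` and split it so the memory class never calls `ggml_build_forward_expand` itself. One hypothetical shape for that split — the names `build_rs_read`/`build_rs_write` are invented for illustration and are not part of this commit:

    // hypothetical follow-up API: each half takes only a ggml_context,
    // so graph expansion (ggml_build_forward_expand) stays with the caller
    ggml_tensor * build_rs_read (ggml_context * ctx, ggml_tensor * s,
                                 int32_t state_size, int32_t n_seqs) const;
    ggml_tensor * build_rs_write(ggml_context * ctx, ggml_tensor * s,
                                 int32_t state_size, int32_t n_seqs) const;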
