
Commit 8e39e04

refactor!: Rename all k/v related values for recurrent/hybrid to r/s
Anywhere that "kv_<state|cell|size|etc>" was used for the recurrent and hybrid caches, I've switched to the more generic "mem_" prefix. The specifics of "k" (key) translate to "r" (recurrent state), and those of "v" (value) translate to "s" (state-space embedding states).

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 88213a9 commit 8e39e04
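
For orientation, a minimal before/after sketch of the naming convention this commit applies. The variable names `state` and `d_r` are illustrative only (not taken verbatim from the tree); `state` would be a `const llama_memory_recurrent_state *`.

    // before: the recurrent/hybrid paths reused the KV-cache vocabulary
    const int64_t  n_kv = state->get_n_kv();     // number of recurrent cells in the current view
    ggml_tensor *  r    = state->get_k_l(il);    // per-layer recurrent-state tensor, stored as "k"
    const uint32_t d_r  = hparams.n_embd_k_s();  // rolling-state embedding size

    // after: "r" = recurrent state, "s" = state-space embedding states
    const int64_t  n_rs = state->get_n_rs();
    ggml_tensor *  r    = state->get_r_l(il);
    const uint32_t d_r  = hparams.n_embd_r();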

File tree: 9 files changed (+241 −241 lines)

src/llama-graph.cpp

Lines changed: 19 additions & 19 deletions
@@ -242,15 +242,15 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -406,18 +406,18 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
+    const int64_t n_rs = mem_state->get_state_recurrent()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->get_state_recurrent()->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recurrent()->s_copy(i);
         }
     }
 }
@@ -1050,14 +1050,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1067,9 +1067,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
+        const auto n_rs = mem_state->get_state_recurrent()->get_n_rs();
 
-        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
     }
 
@@ -1557,9 +1557,9 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
-    const auto n_kv = kv_state->get_n_kv();
+    const auto n_rs = kv_state->get_n_rs();
 
-    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
@@ -1574,7 +1574,7 @@ ggml_tensor * llm_graph_context::build_rs(
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1586,7 +1586,7 @@ ggml_tensor * llm_graph_context::build_rs(
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recurrent();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1600,11 +1600,11 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
-            hparams.n_embd_k_s(), n_seqs);
+            hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1627,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }

src/llama-graph.h

Lines changed: 5 additions & 5 deletions
@@ -191,14 +191,14 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * kv_state) : kv_state(kv_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -306,10 +306,10 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * kv_state) :
+            const llama_memory_hybrid_state * mem_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mem_state(mem_state) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +325,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * kv_state;
+    const llama_memory_hybrid_state * mem_state;
 };
 
 //

src/llama-hparams.cpp

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
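
To make the renamed sizes concrete, here is a worked example under assumed hyperparameters (the numbers are illustrative, not taken from any particular model): for a Mamba-style layer with ssm_d_conv = 4 and ssm_d_inner = 1536, n_embd_r() = (4 - 1) * 1536 = 4608, the size of the rolling convolution state; for an RWKV-style layer with n_embd = 2048, token_shift_count = 2, and wkv_head_size = 64, n_embd_r() = 2 * 2048 = 4096 and n_embd_s() = 2048 * 64 = 131072.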

src/llama-hparams.h

Lines changed: 2 additions & 2 deletions
@@ -184,10 +184,10 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
 
     // whether or not the given layer is recurrent (for hybrid models)
     bool recurrent_layer(uint32_t il) const;

src/llama-memory-hybrid.cpp

Lines changed: 26 additions & 26 deletions
@@ -11,48 +11,48 @@
 llama_memory_hybrid::llama_memory_hybrid(
     const llama_model & model,
         /* attn */
-        ggml_type attn_type_k,
-        ggml_type attn_type_v,
-        bool attn_v_trans,
-        uint32_t attn_kv_size,
-        uint32_t attn_n_pad,
-        uint32_t attn_n_swa,
-        llama_swa_type attn_swa_type,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
         /* recurrent */
-        ggml_type recurrent_type_k,
-        ggml_type recurrent_type_v,
-        uint32_t recurrent_kv_size,
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
         /* common */
         uint32_t n_seq_max,
         bool offload,
         /* layer filters */
-        layer_filter_cb && attn_filter,
-        layer_filter_cb && recurrent_filter) :
+        layer_filter_cb && filter_attn,
+        layer_filter_cb && filter_recurrent) :
     hparams(model.hparams),
     mem_attn(new llama_kv_cache_unified(
         model,
-        attn_filter == nullptr ?
+        filter_attn == nullptr ?
             [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
-            : attn_filter,
-        attn_type_k,
-        attn_type_v,
-        attn_v_trans,
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
         offload,
-        attn_kv_size,
+        kv_size,
         n_seq_max,
-        attn_n_pad,
-        attn_n_swa,
-        attn_swa_type
+        n_pad,
+        n_swa,
+        swa_type
     )),
     mem_recurrent(new llama_memory_recurrent(
         model,
-        recurrent_filter == nullptr ?
+        filter_recurrent == nullptr ?
             [&](int32_t il) { return model.hparams.recurrent_layer(il); }
-            : recurrent_filter,
-        recurrent_type_k,
-        recurrent_type_v,
+            : filter_recurrent,
+        type_r,
+        type_s,
         offload,
-        recurrent_kv_size,
+        rs_size,
         n_seq_max
     )) {}

src/llama-memory-hybrid.h

Lines changed: 12 additions & 12 deletions
@@ -25,23 +25,23 @@ class llama_memory_hybrid : public llama_memory_i {
     llama_memory_hybrid(
         const llama_model & model,
             /* attn */
-            ggml_type attn_type_k,
-            ggml_type attn_type_v,
-            bool attn_v_trans,
-            uint32_t attn_kv_size,
-            uint32_t attn_n_pad,
-            uint32_t attn_n_swa,
-            llama_swa_type attn_swa_type,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            uint32_t kv_size,
+            uint32_t n_pad,
+            uint32_t n_swa,
+            llama_swa_type swa_type,
             /* recurrent */
-            ggml_type recurrent_type_k,
-            ggml_type recurrent_type_v,
-            uint32_t recurrent_kv_size,
+            ggml_type type_r,
+            ggml_type type_s,
+            uint32_t rs_size,
             /* common */
             uint32_t n_seq_max,
             bool offload,
             /* layer filters */
-            layer_filter_cb && attn_filter = nullptr,
-            layer_filter_cb && recurrent_filter = nullptr);
+            layer_filter_cb && filter_attn = nullptr,
+            layer_filter_cb && filter_recurrent = nullptr);
 
     ~llama_memory_hybrid() = default;
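
For reference, a minimal construction sketch against the renamed signature; the `model` variable, the type choices, and the sizes are assumptions for illustration, and both layer filters are left at their nullptr defaults:

    // hypothetical usage; `model` is an already-loaded llama_model
    llama_memory_hybrid mem(
        model,
        /* attn      */ GGML_TYPE_F16, GGML_TYPE_F16, /*v_trans =*/ false,
                        /*kv_size =*/ 4096, /*n_pad =*/ 32, /*n_swa =*/ 0, LLAMA_SWA_TYPE_NONE,
        /* recurrent */ GGML_TYPE_F32, GGML_TYPE_F32, /*rs_size =*/ 1,
        /* common    */ /*n_seq_max =*/ 1, /*offload =*/ true);

With filter_attn and filter_recurrent left as nullptr, layers are split between the two sub-caches via hparams.recurrent_layer(il), as shown in the llama-memory-hybrid.cpp hunk above.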
