Commit 31eeb3d

fix: Use per-layer sizing everywhere in kv caches
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent f8bcfe0 · commit 31eeb3d
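For context: hparams.n_embd_k_s() and hparams.n_embd_v_s() report the extra state rows that recurrent (SSM) layers add on top of the attention K/V rows, and this commit threads the layer index through every call site so that each layer's cache can be sized individually. Below is a minimal sketch of the idea, assuming a toy hybrid model where only some layers carry recurrent state; the struct, the recurrent flag, and every size constant are illustrative stand-ins, not llama.cpp's actual definitions.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_hparams, reduced to the accessors this
// commit touches; layout and numbers are assumptions for illustration only.
struct hparams_sketch {
    std::vector<bool> recurrent; // true for SSM layers, false for attention

    uint32_t n_embd_k_gqa(uint32_t /*il*/) const { return 1024; } // attention K rows
    uint32_t n_embd_v_gqa(uint32_t /*il*/) const { return 1024; } // attention V rows

    // After this commit the state sizes take the layer index, so a hybrid
    // model can report extra rows only for its recurrent layers.
    uint32_t n_embd_k_s(uint32_t il) const { return recurrent[il] ?  128 : 0; }
    uint32_t n_embd_v_s(uint32_t il) const { return recurrent[il] ? 4096 : 0; }
};

int main() {
    hparams_sketch hparams;
    hparams.recurrent = {false, true, false, true}; // toy 4-layer hybrid stack

    for (uint32_t il = 0; il < hparams.recurrent.size(); ++il) {
        // Same expressions as the diff, now fully per-layer:
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
        std::printf("layer %u: K rows = %u, V rows = %u\n", il, n_embd_k_gqa, n_embd_v_gqa);
    }
    return 0;
}

With the old zero-argument accessors, every layer would get the same extra rows, which over-sizes pure-attention layers and cannot express per-layer differences in a hybrid stack.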

File tree

1 file changed: +16 -16 lines

src/llama-kv-cache.cpp

Lines changed: 16 additions & 16 deletions
@@ -76,8 +76,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         const char * dev_name = "CPU";

@@ -1346,7 +1346,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1368,7 +1368,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1392,7 +1392,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1526,7 +1526,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;
@@ -1556,7 +1556,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Read type of value
         int32_t v_type_i_ref;
@@ -1586,7 +1586,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Read type of value
         int32_t v_type_i_ref;
@@ -1881,8 +1881,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);

         const char * dev_name = "CPU";

@@ -2586,7 +2586,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -2606,7 +2606,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2627,7 +2627,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2774,7 +2774,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;
@@ -2802,7 +2802,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;
@@ -2830,7 +2830,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;
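The state_write_data/state_read_data hunks matter for session serialization: the number of elements handled per cell row is n_embd_v_gqa(il) + n_embd_v_s(il), so a single global state size would mis-size the saved blob for a hybrid model. A rough sketch of that bookkeeping, assuming an F32 cache and made-up per-layer sizes; nothing below reflects the actual llama.cpp session format.

#include <cstdint>
#include <cstdio>

// Toy per-layer sizes for a 2-layer hybrid model; the numbers are
// illustrative assumptions, not values from llama.cpp.
static uint32_t n_embd_v_gqa_toy(uint32_t il) { return il == 0 ? 1024 : 0; }
static uint32_t n_embd_v_s_toy  (uint32_t il) { return il == 0 ? 0 : 4096; }

int main() {
    const uint32_t n_cells = 8;             // cells being serialized
    const size_t   elt     = sizeof(float); // assume an F32 value cache

    size_t total = 0;
    for (uint32_t il = 0; il < 2; ++il) {
        // Mirrors the diff: the row size is computed per layer, so each
        // layer can contribute a different number of bytes to the blob.
        const uint32_t n_embd_v_gqa = n_embd_v_gqa_toy(il) + n_embd_v_s_toy(il);
        total += (size_t) n_cells * n_embd_v_gqa * elt;
        std::printf("layer %u: %u elements per cell row\n", il, n_embd_v_gqa);
    }
    std::printf("total bytes: %zu\n", total);
    return 0;
}

The same reasoning applies to the read path: the reader must recompute the identical per-layer row size, or it would consume the wrong number of bytes and desynchronize every layer that follows.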
