Commit 9ba8615

refactor: Remove n_embd_k/v_gqa from recurrent cache
This is no longer needed now that there are separate implementations
#13979 (comment)

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent e15fa60 commit 9ba8615
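
For context, a minimal standalone sketch of what the sizing change means for one layer: the recurrent cache now allocates kv_size cells of n_embd_k_s() / n_embd_v_s() elements each, with no n_embd_*_gqa term. The stand-in functions and numbers below are hypothetical illustrations, not taken from the commit.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for hparams.n_embd_k_s() / hparams.n_embd_v_s();
// the values are made up for illustration only.
static uint32_t n_embd_k_s() { return 128;  }
static uint32_t n_embd_v_s() { return 4096; }

int main() {
    const uint32_t kv_size = 8; // number of recurrent cells (one per tracked sequence)

    // Before this commit: (n_embd_k_gqa(i) + n_embd_k_s()) * kv_size elements per layer.
    // After: only the recurrent state term remains, since the recurrent cache
    // no longer shares sizing with the attention KV cache.
    const uint64_t k_elems = (uint64_t) n_embd_k_s() * kv_size;
    const uint64_t v_elems = (uint64_t) n_embd_v_s() * kv_size;

    printf("K cache elements per layer: %llu\n", (unsigned long long) k_elems);
    printf("V cache elements per layer: %llu\n", (unsigned long long) v_elems);
    return 0;
}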

1 file changed: 16 additions, 23 deletions

src/llama-kv-cache-recurrent.cpp

@@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
         const char * dev_name = "CPU";

         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
@@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }

-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l[i] = k;
@@ -756,14 +753,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
         io.write(&k_type_i, sizeof(k_type_i));

         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         io.write(&k_size_row, sizeof(k_size_row));

         // Read each range of cells of k_size length each into tmp_buf and write out
@@ -776,14 +772,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
             io.write(&v_type_i, sizeof(v_type_i));

             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             io.write(&v_size_row, sizeof(v_size_row));

             // Read each range of cells of v_size length each into tmp_buf and write out
@@ -797,7 +792,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -808,10 +803,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
             io.write(&v_size_el, sizeof(v_size_el));

             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_v_s, sizeof(n_embd_v_s));

             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
@@ -944,7 +939,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Read type of key
         int32_t k_type_i_ref;
@@ -958,7 +952,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -972,7 +966,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -986,7 +979,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -1000,7 +993,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -1020,17 +1013,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
                 return false;
             }

-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_v_s_ref;
+            io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
+            if (n_embd_v_s != n_embd_v_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
                 return false;
             }

             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                     const size_t dst_offset = (head + j * size) * v_size_el;
                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
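
A note on the session-file checks in state_read_data above: the reader recomputes each row size from the state dimensions and rejects a file whose stored value differs. Below is a minimal self-contained sketch of that check pattern; ggml_row_size is stubbed with a plain multiply (the real function also handles quantized types), and the file value is simulated rather than read.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for ggml_row_size(type, n_elems) for a non-quantized type.
static size_t row_size_stub(size_t type_size, uint32_t n_elems) {
    return type_size * n_elems;
}

int main() {
    const uint32_t n_embd_k_s     = 128;  // hypothetical per-layer recurrent K-state size
    const size_t   k_size_row     = row_size_stub(sizeof(float), n_embd_k_s);
    const uint64_t k_size_row_ref = 512;  // pretend this was read from the session file

    // Same shape of check as the k_size_row / k_size_row_ref comparison above.
    if (k_size_row != (size_t) k_size_row_ref) {
        fprintf(stderr, "mismatched key row size (%zu != %zu)\n", k_size_row, (size_t) k_size_row_ref);
        return 1;
    }
    printf("key row size OK: %zu bytes\n", k_size_row);
    return 0;
}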
