@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
21562156}
21572157
21582158// find the upper bound of the cells in use: one past the index of the highest non-empty cell
2159- static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
2160- for (uint32_t i = cache.size - 1; i > 0; --i) {
2161- if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
2162- return i + 1;
2159+ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
2160+ for (uint32_t i = cache.size; i > 0; --i) {
2161+ const llama_kv_cell & cell = cache.cells[i - 1];
2162+
2163+ if (cell.pos >= 0 && !cell.is_empty()) {
2164+ return i;
21632165 }
21642166 }
21652167
@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
81788180 // a heuristic, to avoid attending the full cache if it is not yet utilized
81798181 // after enough generations, the benefit from this heuristic disappears
81808182 // if we start defragmenting the cache, the benefit from this will be more important
8181- kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32 , GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
8183+ kv_self.n = std::min(cparams.n_ctx, std::max(32u , GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
81828184 //kv_self.n = llama_kv_cache_cell_max(kv_self);
81838185
81848186 //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
1261512617 const size_t s_logits = ctx->logits.capacity() * sizeof(float);
1261612618 const size_t s_embedding_size = sizeof(size_t);
1261712619 const size_t s_embedding = ctx->embedding.size() * sizeof(float);
12618- const size_t s_kv_size = sizeof(size_t);
12619- const size_t s_kv_ntok = sizeof(int);
12620+ const size_t s_kv_buf_size = sizeof(size_t);
12621+ const size_t s_kv_head = sizeof(uint32_t);
12622+ const size_t s_kv_size = sizeof(uint32_t);
12623+ const size_t s_kv_used = sizeof(uint32_t);
1262012624 const size_t s_kv = ctx->kv_self.total_size();
12625+ // TODO: this sizes each KV cell for exactly 1 seq_id; account for cells that hold more than 1 seq_id
12626+ const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
12627+ const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
1262112628
1262212629 const size_t s_total = (
1262312630 + s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
1262612633 + s_logits
1262712634 + s_embedding_size
1262812635 + s_embedding
12636+ + s_kv_buf_size
12637+ + s_kv_head
1262912638 + s_kv_size
12630- + s_kv_ntok
12639+ + s_kv_used
1263112640 + s_kv
12641+ + s_kv_cells
1263212642 );
1263312643
1263412644 return s_total;
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1272812738 {
1272912739 const auto & kv_self = ctx->kv_self;
1273012740 const auto & hparams = ctx->model.hparams;
12731- const auto & cparams = ctx->cparams;
1273212741
1273312742 const uint32_t n_layer = hparams.n_layer;
1273412743 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
1273512744 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
12736- const uint32_t n_ctx = cparams.n_ctx;
1273712745
1273812746 const size_t kv_buf_size = kv_self.total_size();
12739- const uint32_t kv_head = kv_self.head ;
12747+ const uint32_t kv_head = llama_kv_cache_cell_max( kv_self) ;
1274012748 const uint32_t kv_size = kv_self.size;
1274112749 const uint32_t kv_used = kv_self.used;
1274212750
@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1275612764
1275712765 // v is not contiguous, copy row by row
1275812766 const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
12759- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx );
12767+ const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size );
1276012768
1276112769 tmp_buf.resize(v_row_size);
1276212770 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
1276612774 }
1276712775 }
1276812776
12769- for (uint32_t i = 0; i < kv_size ; ++i) {
12777+ for (uint32_t i = 0; i < kv_head ; ++i) {
1277012778 const auto & cell = kv_self.cells[i];
1277112779
1277212780 const llama_pos pos = cell.pos;
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1284212850 {
1284312851 const auto & kv_self = ctx->kv_self;
1284412852 const auto & hparams = ctx->model.hparams;
12845- const auto & cparams = ctx->cparams;
1284612853
1284712854 const uint32_t n_layer = hparams.n_layer;
1284812855 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
1284912856 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
12850- const uint32_t n_ctx = cparams.n_ctx;
1285112857
1285212858 size_t kv_buf_size;
1285312859 uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1287012876
1287112877 // v is not contiguous, copy row by row
1287212878 const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
12873- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx );
12879+ const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size );
1287412880
1287512881 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
1287612882 ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1287912885 }
1288012886 }
1288112887
12888+ GGML_ASSERT(kv_self.size == kv_size);
12889+
1288212890 ctx->kv_self.head = kv_head;
1288312891 ctx->kv_self.size = kv_size;
1288412892 ctx->kv_self.used = kv_used;
1288512893
1288612894 ctx->kv_self.cells.resize(kv_size);
1288712895
12888- for (uint32_t i = 0; i < kv_size ; ++i) {
12896+ for (uint32_t i = 0; i < kv_head ; ++i) {
1288912897 llama_pos pos;
1289012898 size_t seq_id_size;
1289112899
@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
1290112909 ctx->kv_self.cells[i].seq_id.insert(seq_id);
1290212910 }
1290312911 }
12912+
12913+ for (uint32_t i = kv_head; i < kv_size; ++i) {
12914+ ctx->kv_self.cells[i].pos = -1;
12915+ ctx->kv_self.cells[i].seq_id.clear();
12916+ }
1290412917 }
1290512918
1290612919 const size_t nread = inp - src;
0 commit comments