Skip to content

Commit 12ee6db

Browse files
committed
llama : add llama_kv_self_seq_pos_min()
1 parent 0073157 commit 12ee6db

File tree

5 files changed

+63
-5
lines changed

5 files changed

+63
-5
lines changed

include/llama.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,10 +730,18 @@ extern "C" {
730730
llama_pos p1,
731731
int d);
732732

733+
// Returns the smallest position present in the KV cache for the specified sequence
734+
// This is typically non-zero only for SWA caches
735+
// Return -1 if the sequence is empty
736+
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
737+
struct llama_context * ctx,
738+
llama_seq_id seq_id);
739+
733740
// Returns the largest position present in the KV cache for the specified sequence
741+
// Return -1 if the sequence is empty
734742
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
735743
struct llama_context * ctx,
736-
llama_seq_id seq_id);
744+
llama_seq_id seq_id);
737745

738746
// Defragment the KV cache
739747
// This will be applied:

src/llama-context.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2466,6 +2466,15 @@ void llama_kv_self_seq_div(
24662466
kv->seq_div(seq_id, p0, p1, d);
24672467
}
24682468

2469+
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
2470+
const auto * kv = ctx->get_kv_self();
2471+
if (!kv) {
2472+
return -1;
2473+
}
2474+
2475+
return kv->seq_pos_min(seq_id);
2476+
}
2477+
24692478
// deprecated
24702479
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
24712480
return llama_kv_self_seq_pos_max(ctx, seq_id);
@@ -2474,7 +2483,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
24742483
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
24752484
const auto * kv = ctx->get_kv_self();
24762485
if (!kv) {
2477-
return 0;
2486+
return -1;
24782487
}
24792488

24802489
return kv->seq_pos_max(seq_id);

src/llama-kv-cache.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,24 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
320320
}
321321
}
322322

323+
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
324+
llama_pos result = std::numeric_limits<llama_pos>::max();
325+
326+
for (uint32_t i = 0; i < size; ++i) {
327+
if (cells[i].has_seq_id(seq_id)) {
328+
result = std::min(result, cells[i].pos);
329+
}
330+
}
331+
332+
if (result == std::numeric_limits<llama_pos>::max()) {
333+
result = -1;
334+
}
335+
336+
return result;
337+
}
338+
323339
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
324-
llama_pos result = 0;
340+
llama_pos result = -1;
325341

326342
for (uint32_t i = 0; i < size; ++i) {
327343
if (cells[i].has_seq_id(seq_id)) {
@@ -1688,8 +1704,13 @@ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, lla
16881704
kv_swa ->seq_div(seq_id, p0, p1, d);
16891705
}
16901706

1707+
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
1708+
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
1709+
return kv_swa->seq_pos_min(seq_id);
1710+
}
1711+
16911712
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
1692-
return kv_base->seq_pos_max(seq_id);
1713+
return kv_swa->seq_pos_max(seq_id);
16931714
}
16941715

16951716
void llama_kv_cache_unified_iswa::restore() {
@@ -2117,8 +2138,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
21172138
}
21182139
}
21192140

2141+
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
2142+
llama_pos result = std::numeric_limits<llama_pos>::max();
2143+
2144+
for (uint32_t i = 0; i < size; ++i) {
2145+
if (cells[i].has_seq_id(seq_id)) {
2146+
result = std::min(result, cells[i].pos);
2147+
}
2148+
}
2149+
2150+
if (result == std::numeric_limits<llama_pos>::max()) {
2151+
result = -1;
2152+
}
2153+
2154+
return result;
2155+
}
2156+
21202157
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
2121-
llama_pos result = 0;
2158+
llama_pos result = -1;
21222159

21232160
for (uint32_t i = 0; i < size; ++i) {
21242161
if (cells[i].has_seq_id(seq_id)) {

src/llama-kv-cache.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
126126
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
127127
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
128128

129+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
129130
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
130131

131132
//
@@ -335,6 +336,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
335336
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
336337
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
337338

339+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
338340
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
339341

340342
//
@@ -437,6 +439,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
437439
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
438440
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
439441

442+
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
440443
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
441444

442445
//

src/llama-memory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class llama_memory_i {
2525
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
2626
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
2727

28+
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
2829
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
2930

3031
virtual bool get_can_edit() const = 0;

0 commit comments

Comments
 (0)