@@ -320,8 +320,24 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
320320 }
321321}
322322
323+ llama_pos llama_kv_cache_unified::seq_pos_min (llama_seq_id seq_id) const {
324+ llama_pos result = std::numeric_limits<llama_pos>::max ();
325+
326+ for (uint32_t i = 0 ; i < size; ++i) {
327+ if (cells[i].has_seq_id (seq_id)) {
328+ result = std::min (result, cells[i].pos );
329+ }
330+ }
331+
332+ if (result == std::numeric_limits<llama_pos>::max ()) {
333+ result = -1 ;
334+ }
335+
336+ return result;
337+ }
338+
323339llama_pos llama_kv_cache_unified::seq_pos_max (llama_seq_id seq_id) const {
324- llama_pos result = 0 ;
340+ llama_pos result = - 1 ;
325341
326342 for (uint32_t i = 0 ; i < size; ++i) {
327343 if (cells[i].has_seq_id (seq_id)) {
@@ -1688,8 +1704,13 @@ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, lla
16881704 kv_swa ->seq_div (seq_id, p0, p1, d);
16891705}
16901706
1707+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min (llama_seq_id seq_id) const {
1708+ // the base cache is a superset of the SWA cache, so we can just check the SWA cache
1709+ return kv_swa->seq_pos_min (seq_id);
1710+ }
1711+
16911712llama_pos llama_kv_cache_unified_iswa::seq_pos_max (llama_seq_id seq_id) const {
1692- return kv_base ->seq_pos_max (seq_id);
1713+ return kv_swa ->seq_pos_max (seq_id);
16931714}
16941715
16951716void llama_kv_cache_unified_iswa::restore () {
@@ -2117,8 +2138,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
21172138 }
21182139}
21192140
2141+ llama_pos llama_kv_cache_recurrent::seq_pos_min (llama_seq_id seq_id) const {
2142+ llama_pos result = std::numeric_limits<llama_pos>::max ();
2143+
2144+ for (uint32_t i = 0 ; i < size; ++i) {
2145+ if (cells[i].has_seq_id (seq_id)) {
2146+ result = std::min (result, cells[i].pos );
2147+ }
2148+ }
2149+
2150+ if (result == std::numeric_limits<llama_pos>::max ()) {
2151+ result = -1 ;
2152+ }
2153+
2154+ return result;
2155+ }
2156+
21202157llama_pos llama_kv_cache_recurrent::seq_pos_max (llama_seq_id seq_id) const {
2121- llama_pos result = 0 ;
2158+ llama_pos result = - 1 ;
21222159
21232160 for (uint32_t i = 0 ; i < size; ++i) {
21242161 if (cells[i].has_seq_id (seq_id)) {
0 commit comments