@@ -559,6 +559,10 @@ uint32_t llama_kv_cache_unified::get_n() const {
     return n;
 }
 
+uint32_t llama_kv_cache_unified::get_size() const {
+    return size;
+}
+
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
     const int32_t ikv = map_layer_ids.at(il);
 
@@ -1568,16 +1572,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
+        uint32_t n_seq_max,
+        uint32_t n_batch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
-    // TODO: provide from the llama_context
-    const uint32_t n_seq_max = 1;
-    const uint32_t n_batch   = hparams.n_swa;
-
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = (hparams.n_swa + n_batch)*n_seq_max;
+    const uint32_t kv_size_swa  = std::min(kv_size, hparams.n_swa*n_seq_max + n_batch);
 
     kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
     kv_swa  = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa),  type_k, type_v, v_trans, offload, kv_size_swa,  padding);
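
Note on the sizing change above: the SWA cache only has to hold the last n_swa positions per sequence plus headroom for the tokens of one batch, and it never needs to be larger than the base cache, hence the std::min clamp. A minimal sketch of the arithmetic; the values of n_swa, n_seq_max, n_batch and kv_size below are made up for illustration and not taken from this commit:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical values, for illustration only
        const uint32_t n_swa     = 4096;   // sliding-window width (hparams.n_swa)
        const uint32_t n_seq_max = 4;      // max number of parallel sequences
        const uint32_t n_batch   = 2048;   // max tokens per batch
        const uint32_t kv_size   = 32768;  // size of the base (non-SWA) cache

        // same clamped formula as in the commit
        const uint32_t kv_size_swa = std::min(kv_size, n_swa*n_seq_max + n_batch);

        printf("kv_size_swa = %u\n", kv_size_swa); // prints 18432 with these values
        return 0;
    }
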
@@ -1705,7 +1707,7 @@ llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return false;
+    return kv_base->get_size() == kv_swa->get_size();
 }
 
 void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
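
Note on get_can_shift() above: after this change the ISWA cache reports that it can shift only when the base and SWA caches end up with the same size, which happens exactly when the clamp above kicks in (n_swa*n_seq_max + n_batch >= kv_size), presumably so a position shift can be applied to both caches consistently. A small self-contained sketch of that condition; toy_iswa_cache and the values below are hypothetical, not llama.cpp API:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // toy model of the base/SWA cache pair: shifting is allowed only when the
    // two underlying caches have the same size
    struct toy_iswa_cache {
        uint32_t size_base;
        uint32_t size_swa;

        bool get_can_shift() const { return size_base == size_swa; }
    };

    int main() {
        // hypothetical values, for illustration only
        const uint32_t n_swa = 4096, n_seq_max = 1, n_batch = 512, kv_size = 4096;

        const toy_iswa_cache cache = {
            /*size_base =*/ kv_size,
            /*size_swa  =*/ std::min(kv_size, n_swa*n_seq_max + n_batch),
        };

        // here the SWA size is clamped down to kv_size, the sizes match,
        // and shifting is allowed
        printf("can_shift = %s\n", cache.get_can_shift() ? "true" : "false");
        return 0;
    }
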