
Commit 9473e16

cont : take into account n_seq_max and n_batch
ggml-ci
Parent commit: dfc8fdc

3 files changed: +13 −6 lines

src/llama-kv-cache.cpp

Lines changed: 8 additions & 6 deletions
@@ -559,6 +559,10 @@ uint32_t llama_kv_cache_unified::get_n() const {
     return n;
 }
 
+uint32_t llama_kv_cache_unified::get_size() const {
+    return size;
+}
+
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
     const int32_t ikv = map_layer_ids.at(il);
 
@@ -1568,16 +1572,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   v_trans,
                      bool   offload,
                  uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_batch,
                  uint32_t   padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
-    // TODO: provide from the llama_context
-    const uint32_t n_seq_max = 1;
-    const uint32_t n_batch   = hparams.n_swa;
-
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = (hparams.n_swa + n_batch)*n_seq_max;
+    const uint32_t kv_size_swa  = std::min(kv_size, hparams.n_swa*n_seq_max + n_batch);
 
     kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
     kv_swa  = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa),  type_k, type_v, v_trans, offload, kv_size_swa,  padding);
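
Note: as an illustration of the new SWA sizing rule above, here is a minimal standalone sketch. The sliding-window cache is sized for one window per sequence plus one batch of in-flight tokens, clamped so it never exceeds the base cache. The numeric values for kv_size, n_swa, n_seq_max and n_batch below are assumed examples, not values taken from this commit.

    // Minimal sketch mirroring the kv_size_swa expression introduced above.
    // Example values only; the real values come from cparams and hparams.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // kv_size   : size of the base (full-context) cache, i.e. cparams.n_ctx
    // n_swa     : sliding-window size, i.e. hparams.n_swa
    // n_seq_max : maximum number of parallel sequences, i.e. cparams.n_seq_max
    // n_batch   : logical batch size, i.e. cparams.n_batch
    static uint32_t swa_cache_size(uint32_t kv_size, uint32_t n_swa, uint32_t n_seq_max, uint32_t n_batch) {
        // one window per sequence plus one batch of in-flight tokens,
        // never larger than the base cache
        return std::min(kv_size, n_swa*n_seq_max + n_batch);
    }

    int main() {
        const uint32_t kv_size   = 8192; // assumed context size
        const uint32_t n_swa     = 1024; // assumed window size
        const uint32_t n_seq_max = 2;    // assumed max parallel sequences
        const uint32_t n_batch   = 2048; // assumed batch size

        // prints kv_size_swa = 4096, i.e. min(8192, 1024*2 + 2048)
        std::printf("kv_size_base = %u\n", kv_size);
        std::printf("kv_size_swa  = %u\n", swa_cache_size(kv_size, n_swa, n_seq_max, n_batch));
        return 0;
    }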
@@ -1705,7 +1707,7 @@ llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return false;
+    return kv_base->get_size() == kv_swa->get_size();
 }
 
 void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
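
Note: the get_can_shift() change above follows from the new sizing rule. Ignoring any internal padding of the cache sizes, the base and SWA caches end up with equal size exactly when kv_size <= n_swa*n_seq_max + n_batch, i.e. when std::min clamps kv_size_swa to kv_size_base. A small sketch of that condition; the helper name and example values are illustrative, not from this commit.

    // Sketch of when the iSWA cache reports get_can_shift() == true under the
    // new sizing, i.e. when both underlying caches have the same size.
    // Helper name and values are illustrative only; padding is ignored here.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static bool iswa_can_shift(uint32_t kv_size, uint32_t n_swa, uint32_t n_seq_max, uint32_t n_batch) {
        const uint32_t kv_size_base = kv_size;
        const uint32_t kv_size_swa  = std::min(kv_size, n_swa*n_seq_max + n_batch);
        // mirrors: return kv_base->get_size() == kv_swa->get_size();
        return kv_size_base == kv_size_swa;
    }

    int main() {
        // 4096 <= 1024*2 + 2048, so the SWA cache is not clamped -> can shift
        std::printf("%d\n", iswa_can_shift(4096, 1024, 2, 2048)); // prints 1
        // 8192 >  1024*2 + 2048, so the SWA cache is smaller     -> cannot shift
        std::printf("%d\n", iswa_can_shift(8192, 1024, 2, 2048)); // prints 0
        return 0;
    }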

src/llama-kv-cache.h

Lines changed: 3 additions & 0 deletions
@@ -160,6 +160,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     //
 
     uint32_t get_n() const;
+    uint32_t get_size() const;
 
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -301,6 +302,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
                      bool   v_trans,
                      bool   offload,
                  uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_batch,
                  uint32_t   padding);
 
     ~llama_kv_cache_unified_iswa() = default;

src/llama-model.cpp

Lines changed: 2 additions & 0 deletions
@@ -13027,6 +13027,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         !cparams.flash_attn,
                         cparams.offload_kqv,
                         cparams.n_ctx,
+                        cparams.n_seq_max,
+                        cparams.n_batch,
                         padding);
             } else {
                 res = new llama_kv_cache_unified(
