@@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
-        uint32_t padding,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
+        llama_swa_type swa_type) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
-    this->type_k = type_k;
-    this->type_v = type_v;
+    GGML_ASSERT(kv_size % n_pad == 0);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
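
Note: the assert loses its explanatory string, but the invariant is unchanged: the cache size must be a multiple of n_pad. A minimal caller-side sketch of keeping that invariant (the helper and the example values are illustrative, not code from this PR):

#include <cstdint>

// round x up to a multiple of n; same bit trick as ggml's GGML_PAD,
// which assumes n is a power of two (n_pad is typically 32, or 256 with flash attention)
static uint32_t round_up(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

// uint32_t kv_size = round_up(n_ctx, n_pad); // now kv_size % n_pad == 0 holds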
@@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const size_t memory_size_k = size_k_bytes();
     const size_t memory_size_v = size_v_bytes();
 
-    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_k + memory_size_v) / (1024.0f*1024.0f), kv_size, (int) layers.size(),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_k + memory_size_v) / (1024.0f*1024.0f), kv_size, (int) layers.size(), n_seq_max,
             ggml_type_name(type_k), (float)memory_size_k / (1024.0f*1024.0f),
             ggml_type_name(type_v), (float)memory_size_v / (1024.0f*1024.0f));
 }
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
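
For reference, a standalone restatement of the fragmentation heuristic with made-up numbers (not part of the change):

#include <algorithm>
#include <cstdint>

static float fragmentation(uint32_t n, uint32_t used, uint32_t n_pad) {
    return n >= 2048 ? std::max(0.0f, 1.0f - float(used + n_pad)/n) : 0.0f;
}

// fragmentation(4096, 3000, 32) ≈ 0.26 -> defrag is queued for any thold below that
// fragmentation(1024,  100, 32) == 0.0 -> contexts under 2048 cells are never defragged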
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
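
A quick self-check of the attention-window heuristic under assumed values; the pad lambda mirrors GGML_PAD's power-of-two rounding:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t size = 8192, n_pad = 32;
    auto pad = [](uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); };
    assert(std::min(size, std::max(n_pad, pad(   0u, n_pad))) ==   32u); // empty cache: one padded block
    assert(std::min(size, std::max(n_pad, pad( 100u, n_pad))) ==  128u); // cell_max rounded up to n_pad
    assert(std::min(size, std::max(n_pad, pad(9000u, n_pad))) == 8192u); // never exceeds the cache size
}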
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     return true;
 }
 
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
 bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
 
@@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         llama_seq_id seq_id;
         io.read_to(&seq_id, sizeof(seq_id));
 
-        // TODO: llama_kv_cache_unified should have a notion of max sequences
-        //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-        if (seq_id < 0) {
-            //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-            LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+        if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+            LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
             return false;
         }
 
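
This resolves the old TODO: the cache now carries its own n_seq_max, so sequence ids read from a session file are bounds-checked instead of only sign-checked. A standalone restatement (the helper name is mine, not the PR's):

#include <cstdint>

static bool seq_id_ok(int32_t seq_id, uint32_t n_seq_max) {
    return seq_id >= 0 && (uint32_t) seq_id < n_seq_max;
}

// seq_id_ok(3, 4) -> true
// seq_id_ok(4, 4) -> false: e.g. a state saved with more parallel sequences
//                    than the restoring context was created with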
@@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         ggml_type type_v,
         bool v_trans,
         bool offload,
-        uint32_t kv_size,
         bool swa_full,
+        uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_batch,
-        uint32_t padding) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
     if (swa_full) {
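
A worked example of the SWA sizing formula with assumed numbers (the helper is a sketch, not code from this PR): one attention window per sequence plus one batch, rounded up to n_pad, and never more than the base cache.

#include <algorithm>
#include <cstdint>

static uint32_t swa_cache_size(uint32_t n_swa, uint32_t n_seq_max, uint32_t n_batch,
                               uint32_t n_pad, uint32_t size_base) {
    const uint32_t need   = n_swa*n_seq_max + n_batch;          // worst-case SWA usage
    const uint32_t padded = (need + n_pad - 1) & ~(n_pad - 1);  // GGML_PAD equivalent
    return std::min(size_base, padded);
}

// swa_cache_size(4096, 2, 2048, 32, 32768) == 10240
// -> here the SWA cache needs less than a third of the base cache size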
@@ -1680,14 +1654,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, padding,
+            v_trans, offload, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, padding,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
-int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
-    return kv_base->get_n_tokens();
-}
-
-int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
-    return kv_base->get_used_cells();
-}
-
-llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
-    return kv_base->get_pos_max();
-}
-
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
@@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_type type_k,
         ggml_type type_v,
         bool offload,
-        uint32_t kv_size) : hparams(model.hparams) {
+        uint32_t kv_size,
+        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     head = 0;
     size = kv_size;
     used = 0;
 
-    this->type_k = type_k;
-    this->type_v = type_v;
-
     cells.clear();
     cells.resize(kv_size);
 
@@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
     return false;
 }
 
@@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                 return false;
             }
             if (j > 0) {
@@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }