@@ -23,13 +23,14 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }
 
 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_model & model,
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        bool offload,
-        uint32_t kv_size,
-        uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+        const llama_model & model,
+        layer_filter_cb && filter,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        uint32_t kv_size,
+        uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     has_shift = false;
     can_shift = true;
 
@@ -73,6 +74,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     cells.resize(kv_size);
 
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (filter && !filter(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
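A minimal caller-side sketch of the new layer filter (not part of this commit): it assumes layer_filter_cb is an empty-testable callable such as std::function<bool(int32_t il)>, as suggested by the `filter && !filter(il)` check above, and that the llama.cpp internal headers declaring llama_kv_cache_unified and llama_model are in scope. The even-layer predicate, cache types, and sizes below are purely illustrative.

// Hypothetical usage sketch (not from this commit).
#include <cstdint>

static void build_partial_cache(const llama_model & model) {
    // Allocate K/V buffers only for even-numbered layers; odd layers hit the new
    // `if (filter && !filter(il)) { ...; continue; }` branch and are skipped.
    auto filter = [](int32_t il) { return il % 2 == 0; };

    llama_kv_cache_unified kv(
        model,
        std::move(filter),  // new layer_filter_cb parameter
        GGML_TYPE_F16,      // type_k  (illustrative)
        GGML_TYPE_F16,      // type_v  (illustrative)
        /*v_trans =*/ false,
        /*offload =*/ true,
        /*kv_size =*/ 4096, // illustrative
        /*padding =*/ 32);  // illustrative
}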
@@ -1482,6 +1488,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
+
     io.read_to(&v_trans, sizeof(v_trans));
     io.read_to(&n_layer, sizeof(n_layer));
 