@@ -17,40 +17,33 @@
 //
 
 llama_kv_cache::llama_kv_cache(
-        const llama_model & model,
-        layer_filter_cb && filter,
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        bool offload,
-        bool unified,
-        uint32_t kv_size,
-        uint32_t n_seq_max,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type) :
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        bool unified,
+        uint32_t kv_size,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        const layer_filter_cb & filter,
+        const layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
     n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
+    const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
             };
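The new trailing parameters are only exercised later in this diff as filter(il) (a keep/skip predicate) and reuse(il) (which layer's KV buffers to share, with a negative value meaning no reuse). Below is a minimal caller-side sketch of callbacks with those shapes, assuming std::function-style typedefs; the definitions here are illustrative and are not the exact ones from the llama.cpp headers.

// Hypothetical callback shapes, inferred only from how this constructor
// calls filter(il) and reuse(il) further down in the diff.
#include <cstdint>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;    // keep layer il in the cache?
using layer_reuse_cb  = std::function<int32_t(int32_t il)>; // layer to reuse, or < 0 for none

int main() {
    layer_filter_cb filter = [](int32_t /*il*/) { return true; }; // keep every layer
    layer_reuse_cb  reuse  = [](int32_t /*il*/) { return -1;   }; // no cross-layer reuse

    // These would be passed as the new trailing constructor arguments,
    // replacing the old leading `layer_filter_cb && filter` parameter.
    return (filter(0) && reuse(0) < 0) ? 0 : 1;
}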
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
             __func__, hparams.n_embd_v_gqa_max());
     }
 
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (!hparams.has_kv(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+            continue;
+        }
+
         if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
             continue;
         }
 
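The ggml context in the first hunk is sized for 2u*(1 + n_stream)*n_layer_kv tensor headers, while this loop walks all hparams.n_layer layers and skips the ones without KV, so the patch implicitly relies on n_layer_kv() counting exactly the layers for which has_kv(il) holds. A self-contained toy model of that assumed relationship follows; the struct is illustrative and is not llama.cpp's llama_hparams.

// Toy stand-in for the per-layer KV bookkeeping the patch reads from hparams.
#include <cassert>
#include <cstdint>
#include <vector>

struct toy_hparams {
    std::vector<bool> layer_has_kv; // one flag per layer

    uint32_t n_layer() const { return (uint32_t) layer_has_kv.size(); }
    bool     has_kv(uint32_t il) const { return layer_has_kv[il]; }

    // number of layers that actually get K/V tensors
    uint32_t n_layer_kv() const {
        uint32_t n = 0;
        for (uint32_t il = 0; il < n_layer(); ++il) {
            if (has_kv(il)) {
                n++;
            }
        }
        return n;
    }
};

int main() {
    // e.g. a model where layers 2 and 5 carry no KV cache
    toy_hparams hp = { { true, true, false, true, true, false } };

    // the context is sized with 2u*(1 + n_stream)*n_layer_kv tensor headers,
    // and the creation loop skips !has_kv(il), so the two counts must agree
    assert(hp.n_layer_kv() == 4);
    return 0;
}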
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
         layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+    if (reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
 
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+        for (uint32_t il = 0; il < hparams.n_layer; il++) {
+            const int32_t il_reuse = reuse(il);
+
+            if (il_reuse < 0) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                 continue;
             }
 
-            const bool is_swa = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+                continue;
+            }
 
             GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+
             map_layer_ids[il] = map_layer_ids[il_reuse];
 
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
         }
     }
 
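The removed GEMMA3N special case computed il_reuse = n_layer_cache - (is_swa ? 2 : 1) with n_layer_cache hard-coded to 20; under the new interface the same mapping would instead be supplied by the caller as a reuse callback. A hedged sketch of such a callback is shown below, using a stand-in is_swa predicate since the model-side wiring is outside this diff.

// Illustrative reuse callback reproducing the removed GEMMA3N-style mapping;
// the callback type is assumed to mirror how the constructor calls reuse(il).
#include <cstdint>
#include <functional>
#include <iostream>

using layer_reuse_cb = std::function<int32_t(int32_t il)>;

int main() {
    const uint32_t n_layer       = 24; // illustrative total layer count
    const uint32_t n_layer_cache = 20; // the constant the removed block hard-coded

    // stand-in for hparams.is_swa(il); the real predicate comes from the model
    auto is_swa = [](int32_t il) { return il % 2 == 0; };

    // Reproduces the removed formula: il_reuse = n_layer_cache - (is_swa ? 2 : 1),
    // and returns -1 for layers that keep their own K/V tensors.
    layer_reuse_cb reuse = [&](int32_t il) -> int32_t {
        if ((uint32_t) il < n_layer_cache) {
            return -1;
        }
        return (int32_t) n_layer_cache - (is_swa(il) ? 2 : 1);
    };

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::cout << "layer " << il << " -> " << reuse((int32_t) il) << "\n";
    }
    return 0;
}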