 17 |  17 |  //
 18 |  18 | 
 19 |  19 |  llama_kv_cache::llama_kv_cache(
 20 |     | -        const llama_model & model,
 21 |     | -        layer_filter_cb && filter,
 22 |     | -        ggml_type type_k,
 23 |     | -        ggml_type type_v,
 24 |     | -        bool v_trans,
 25 |     | -        bool offload,
 26 |     | -        bool unified,
 27 |     | -        uint32_t kv_size,
 28 |     | -        uint32_t n_seq_max,
 29 |     | -        uint32_t n_pad,
 30 |     | -        uint32_t n_swa,
 31 |     | -        llama_swa_type swa_type) :
    |  20 | +        const llama_model & model,
    |  21 | +        ggml_type type_k,
    |  22 | +        ggml_type type_v,
    |  23 | +        bool v_trans,
    |  24 | +        bool offload,
    |  25 | +        bool unified,
    |  26 | +        uint32_t kv_size,
    |  27 | +        uint32_t n_seq_max,
    |  28 | +        uint32_t n_pad,
    |  29 | +        uint32_t n_swa,
    |  30 | +        llama_swa_type swa_type,
    |  31 | +        const layer_filter_cb & filter,
    |  32 | +        const layer_reuse_cb & reuse) :
 32 |  33 |      model(model), hparams(model.hparams), v_trans(v_trans),
 33 |  34 |      n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 34 |  35 | 
 35 |  36 |      GGML_ASSERT(kv_size % n_pad == 0);
 36 |  37 | 
 37 |     | -    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
 38 |     | -    auto n_layer_cache = hparams.n_layer;
 39 |     | -    if (model.arch == LLM_ARCH_GEMMA3N) {
 40 |     | -        n_layer_cache = 20;
 41 |     | -    }
 42 |     | -    if (model.arch == LLM_ARCH_GLM4_MOE) {
 43 |     | -        // GLM-4.5: Only process up to last layer, skip final NextN layer
 44 |     | -        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
 45 |     | -    }
    |  38 | +    const uint32_t n_layer_kv = hparams.n_layer_kv();
 46 |  39 | 
 47 |  40 |      // create a context for each buffer type
 48 |  41 |      std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
 49 |  42 |      auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
 50 |  43 |          auto it = ctx_map.find(buft);
 51 |  44 |          if (it == ctx_map.end()) {
 52 |  45 |              ggml_init_params params = {
 53 |     | -                /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
    |  46 | +                /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
 54 |  47 |                  /*.mem_buffer =*/ NULL,
 55 |  48 |                  /*.no_alloc =*/ true,
 56 |  49 |              };
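In this first hunk the per-architecture special cases for GEMMA3N and GLM4_MOE disappear: the filter and reuse callbacks move to the end of the parameter list (taken by const reference rather than an owning rvalue filter), and the context size estimate is driven by hparams.n_layer_kv() instead of the hand-maintained n_layer_cache. A minimal sketch of the assumed semantics of n_layer_kv(), written as a free helper so it does not claim to be the actual llama_hparams implementation: it just counts the layers for which hparams.has_kv(il) is true, so the mem_size reservation above covers only layers that really allocate K/V tensors.

    // sketch only: assumed equivalent of hparams.n_layer_kv() - count the layers
    // that actually get K/V tensors; the real definition lives in llama-hparams.cpp
    // and may differ
    static uint32_t count_kv_layers(const llama_hparams & hparams) {
        uint32_t res = 0;

        for (uint32_t il = 0; il < hparams.n_layer; ++il) {
            if (hparams.has_kv(il)) {
                res++;
            }
        }

        return res;
    }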
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
 97 |  90 |              __func__, hparams.n_embd_v_gqa_max());
 98 |  91 |      }
 99 |  92 | 
100 |     | -    for (uint32_t il = 0; il < n_layer_cache; il++) {
    |  93 | +    for (uint32_t il = 0; il < hparams.n_layer; il++) {
    |  94 | +        if (!hparams.has_kv(il)) {
    |  95 | +            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
    |  96 | +            continue;
    |  97 | +        }
    |  98 | +
101 |  99 |          if (filter && !filter(il)) {
102 |     | -            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
    | 100 | +            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
103 | 101 |              continue;
104 | 102 |          }
105 | 103 | 
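The allocation loop now walks all of hparams.n_layer and relies on two checks: hparams.has_kv(il) drops layers that never get a KV cache (for example the NextN layers previously excluded by the GLM4_MOE branch), while the optional filter drops layers the caller does not want in this particular cache instance. A hedged sketch of a caller-side filter, assuming layer_filter_cb behaves like a std::function<bool(int32_t)>-style predicate (the variable name is illustrative): a split full-context/SWA setup could keep only the non-SWA layers here and hand the inverse predicate to a second, sliding-window cache.

    // sketch only: keep non-SWA layers in this cache; a companion SWA cache
    // would be constructed with the negated predicate
    auto filter_base = [&hparams](int32_t il) {
        return !hparams.is_swa(il);
    };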
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
147 | 145 |          layers.push_back({ il, k, v, k_stream, v_stream, });
148 | 146 |      }
149 | 147 | 
150 |     | -    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
151 |     | -    if (model.arch == LLM_ARCH_GEMMA3N) {
152 |     | -        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
    | 148 | +    if (reuse) {
    | 149 | +        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
153 | 150 | 
154 |     | -        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
155 |     | -            if (filter && !filter(il)) {
156 |     | -                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
    | 151 | +        for (uint32_t il = 0; il < hparams.n_layer; il++) {
    | 152 | +            const int32_t il_reuse = reuse(il);
    | 153 | +
    | 154 | +            if (il_reuse < 0) {
    | 155 | +                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
157 | 156 |                  continue;
158 | 157 |              }
159 | 158 | 
160 |     | -            const bool is_swa = hparams.is_swa(il);
161 |     | -            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
    | 159 | +            if (filter && !filter(il)) {
    | 160 | +                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
    | 161 | +                continue;
    | 162 | +            }
162 | 163 | 
163 | 164 |              GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
    | 165 | +
164 | 166 |              map_layer_ids[il] = map_layer_ids[il_reuse];
165 | 167 | 
166 |     | -            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
    | 168 | +            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
167 | 169 |          }
168 | 170 |      }
169 | 171 | 
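The reuse block is the same change in callback form: instead of hard-coding the GEMMA3N layout, the constructor asks reuse(il) which earlier layer a given layer should share K/V tensors with, and a negative return means "no reuse". A sketch of a callback that reproduces the removed special case, assuming layer_reuse_cb maps an int32_t layer index to an int32_t target index (the names n_layer_cache and reuse_gemma3n are illustrative, not part of this change):

    // sketch only: layers >= 20 keep no K/V of their own and reuse the last
    // cached layer of the matching type, mirroring the removed
    // "n_layer_cache - (is_swa ? 2 : 1)" logic
    const int32_t n_layer_cache = 20;

    auto reuse_gemma3n = [&hparams, n_layer_cache](int32_t il) -> int32_t {
        if (il < n_layer_cache) {
            return -1; // this layer has its own K/V tensors
        }

        return hparams.is_swa(il) ? n_layer_cache - 2 : n_layer_cache - 1;
    };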