@@ -13,6 +13,7 @@
 struct llama_cparams;
 struct llama_hparams;
 struct llama_ubatch;
+struct llama_sbatch;

 struct llama_kv_cache : public llama_memory_i {
     // can be used to query data from the model if needed
@@ -44,6 +45,9 @@ struct llama_kv_cache : public llama_memory_i {

     virtual bool find_slot(const llama_ubatch & batch) = 0;

+    // different KV caches require different batch splitting strategies
+    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
+
     // simulate full cache, used for allocating worst-case compute buffers
     virtual void set_full() = 0;

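The new pure virtual moves the batch-splitting decision behind the cache interface, so the decode path no longer has to branch on the concrete cache type. Below is a minimal sketch of the intended call site; it assumes the internal `llama_sbatch`/`llama_ubatch` helpers from `llama-batch.h`, and the function name `process_batch` is purely illustrative:

```cpp
// sketch only: how a decode loop could drive the new interface
// (assumes sbatch was already initialized from the incoming llama_batch)
static void process_batch(llama_kv_cache & kv, llama_sbatch & sbatch,
                          uint32_t n_ubatch, bool embd_pooled) {
    while (sbatch.n_tokens > 0) {
        // the cache picks the splitting strategy that suits its layout
        llama_ubatch ubatch = kv.ubatch_next(sbatch, n_ubatch, embd_pooled);

        // reserve cells for this micro-batch before building the graph
        if (!kv.find_slot(ubatch)) {
            break; // cache full - the caller decides how to recover
        }

        // ... build and evaluate the compute graph for `ubatch` ...
    }
}
```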
@@ -139,6 +143,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;

+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
     static uint32_t get_padding(const llama_cparams & cparams);

     // find how many cells are currently in use
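For the unified cache, a contiguous token-order split is enough, so the override can likely just forward to the simple splitter. A sketch, assuming `llama_sbatch::split_simple` from `llama-batch.h` (the definition itself is not part of this diff):

```cpp
// llama-kv-cache.cpp (sketch)
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    GGML_UNUSED(embd_pooled); // pooling does not change how a unified cache splits

    // plain split: take up to n_ubatch tokens in batch order
    return sbatch.split_simple(n_ubatch);
}
```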
@@ -263,6 +269,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;

+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
     // find how many cells are currently in use
     uint32_t cell_max() const;

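The recurrent cache is why the strategy has to differ: its per-sequence state means ubatches should keep sequences aligned, and pooled embeddings cannot be split across ubatches. A plausible override, assuming the `split_equal`/`split_seq` helpers that `llama_sbatch` exposes in `llama-batch.h`:

```cpp
// llama-kv-cache.cpp (sketch)
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (embd_pooled) {
        // pooled embeddings require whole sequences per ubatch
        return sbatch.split_seq(n_ubatch);
    }

    // otherwise keep sequences aligned: take an equal number of tokens from each
    return sbatch.split_equal(n_ubatch);
}
```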