Skip to content

Commit 2a89bac

Browse files
committed
kv-cache : add ubatch_next()
ggml-ci
1 parent 2c72b74 commit 2a89bac

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,22 +1250,7 @@ int llama_context::decode(llama_batch & inp_batch) {
12501250
int64_t n_outputs_prev = 0;
12511251

12521252
while (sbatch.n_tokens > 0) {
1253-
llama_ubatch ubatch = llama_ubatch();
1254-
1255-
const auto & n_ubatch = cparams.n_ubatch;
1256-
1257-
if (is_recurrent) {
1258-
if (embd_pooled) {
1259-
// Pooled embeddings cannot be split across ubatches (yet)
1260-
ubatch = sbatch.split_seq(cparams.n_ubatch);
1261-
} else {
1262-
// recurrent model architectures are easier to implement
1263-
// with equal-length sequences
1264-
ubatch = sbatch.split_equal(cparams.n_ubatch);
1265-
}
1266-
} else {
1267-
ubatch = sbatch.split_simple(n_ubatch);
1268-
}
1253+
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
12691254

12701255
// count the outputs in this u_batch
12711256
{
@@ -1435,7 +1420,7 @@ int llama_context::decode(llama_batch & inp_batch) {
14351420

14361421
// - do not defrag small contexts (i.e. < 2048 tokens)
14371422
// - count the padding towards the number of used tokens
1438-
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;
1423+
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->padding)/float(kv->n)) : 0.0f;
14391424

14401425
// queue defragmentation for next llama_kv_cache_update
14411426
if (fragmentation > cparams.defrag_thold) {

src/llama-kv-cache.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,14 @@ bool llama_kv_cache_unified::find_slot(
476476
return true;
477477
}
478478

479+
// Produce the next micro-batch from the sequence batch for a unified KV cache.
// The unified cache has no per-sequence layout constraints, so a simple
// contiguous split of up to n_ubatch tokens is sufficient.
// @param sbatch      sequence batch being consumed (advanced by the split)
// @param n_ubatch    maximum number of tokens per micro-batch
// @param embd_pooled unused here — pooled embeddings only affect the
//                    recurrent cache's splitting strategy
llama_ubatch llama_kv_cache_unified::ubatch_next(
        llama_sbatch & sbatch,
        uint32_t n_ubatch,
        bool embd_pooled) const {
    GGML_UNUSED(embd_pooled);
    return sbatch.split_simple(n_ubatch);
}
486+
479487
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
480488
// the FA kernels require padding to avoid extra runtime boundary checks
481489
return cparams.flash_attn ? 256u : 32u;
@@ -1539,6 +1547,15 @@ bool llama_kv_cache_recurrent::find_slot(
15391547
return n >= n_seqs;
15401548
}
15411549

1550+
// Produce the next micro-batch from the sequence batch for a recurrent KV cache.
// Recurrent architectures are easier to implement with equal-length sequences,
// so the default strategy is an equal split; pooled embeddings are the
// exception, since they cannot be split across ubatches (yet) and must be
// split per sequence instead.
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (!embd_pooled) {
        // common case: equal-length sequence split
        return sbatch.split_equal(n_ubatch);
    }

    // Pooled embeddings cannot be split across ubatches (yet)
    return sbatch.split_seq(n_ubatch);
}
1558+
15421559
uint32_t llama_kv_cache_recurrent::cell_max() const {
15431560
for (uint32_t i = size; i > 0; --i) {
15441561
const llama_kv_cell & cell = cells[i - 1];

src/llama-kv-cache.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
struct llama_cparams;
1414
struct llama_hparams;
1515
struct llama_ubatch;
16+
struct llama_sbatch;
1617

1718
struct llama_kv_cache : public llama_memory_i {
1819
// can be used to query data from the model if needed
@@ -44,6 +45,9 @@ struct llama_kv_cache : public llama_memory_i {
4445

4546
virtual bool find_slot(const llama_ubatch & batch) = 0;
4647

48+
// different KV caches require different batch splitting strategies
49+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
50+
4751
// simulate full cache, used for allocating worst-case compute buffers
4852
virtual void set_full() = 0;
4953

@@ -139,6 +143,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
139143
// to the first cell of the slot.
140144
bool find_slot(const llama_ubatch & batch) override;
141145

146+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
147+
142148
static uint32_t get_padding(const llama_cparams & cparams);
143149

144150
// find how many cells are currently in use
@@ -263,6 +269,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
263269
// to the first cell of the slot.
264270
bool find_slot(const llama_ubatch & batch) override;
265271

272+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
273+
266274
// find how many cells are currently in use
267275
uint32_t cell_max() const;
268276

0 commit comments

Comments
 (0)