Commit 8aee2e7

feat: Add split_equal to init(...) signature
This will enable the hybrid cache to control the split type for all children together.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent db9a618 commit 8aee2e7
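For context on the commit message: below is a hypothetical, self-contained C++ sketch of the pattern this change enables, namely a parent ("hybrid") cache forwarding one split_equal flag to all of its children so they split the batch the same way. The hybrid cache itself is not part of this commit (it lives on the HybridCache branch), and every name here other than split_equal is illustrative rather than taken from llama.cpp.

#include <cstdio>
#include <memory>
#include <vector>

// stand-in for llama_kv_cache::init(batch, n_ubatch, embd_pooled, logits_all, split_equal)
struct kv_cache_i {
    virtual ~kv_cache_i() = default;
    virtual void init(int n_tokens, int n_ubatch, bool split_equal = false) = 0;
};

struct kv_cache_unified : kv_cache_i {
    void init(int n_tokens, int n_ubatch, bool split_equal = false) override {
        // mirrors the new branch in llama_kv_cache_unified::init
        std::printf("unified:   %s split of %d tokens, n_ubatch = %d\n",
                    split_equal ? "equal" : "simple", n_tokens, n_ubatch);
    }
};

struct kv_cache_recurrent : kv_cache_i {
    void init(int n_tokens, int n_ubatch, bool split_equal = true) override {
        // the recurrent cache only supports equal splits (cf. assert(split_equal))
        std::printf("recurrent: %s split of %d tokens, n_ubatch = %d\n",
                    split_equal ? "equal" : "simple", n_tokens, n_ubatch);
    }
};

// the hybrid parent: one flag decides the split type for every child
struct kv_cache_hybrid : kv_cache_i {
    std::vector<std::unique_ptr<kv_cache_i>> children;

    void init(int n_tokens, int n_ubatch, bool /*split_equal*/ = true) override {
        for (auto & child : children) {
            // force equal splits so the children's ubatch boundaries stay aligned
            child->init(n_tokens, n_ubatch, /*split_equal=*/true);
        }
    }
};

int main() {
    kv_cache_hybrid hybrid;
    hybrid.children.push_back(std::make_unique<kv_cache_unified>());
    hybrid.children.push_back(std::make_unique<kv_cache_recurrent>());

    hybrid.init(/*n_tokens=*/512, /*n_ubatch=*/128); // both children use the equal split
    return 0;
}

The differing defaults in the sketch (false for the unified child, true for the recurrent one) mirror the defaults added in llama-kv-cache.h below. Because C++ binds default arguments at the static type of the call, the hybrid parent passes the flag explicitly rather than relying on each child's default.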

2 files changed: +34 −11 lines


src/llama-kv-cache.cpp

Lines changed: 26 additions & 7 deletions
@@ -352,14 +352,19 @@ llama_memory_decode_state_ptr llama_kv_cache_unified::init(
         const llama_batch & batch,
         uint32_t n_ubatch,
         bool embd_pooled,
-        bool logits_all) {
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
     while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
+        if (split_equal) {
+            ubatches.push_back(sbatch.split_equal(n_ubatch));
+        } else {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
     }
 
     auto heads = prepare(ubatches);
@@ -1821,17 +1826,24 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
 
     while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
-
-        ubatches.push_back(ubatch);
+        if (split_equal) {
+            ubatches.push_back(sbatch.split_equal(n_ubatch));
+        } else {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
     }
 
     auto heads_base = kv_base->prepare(ubatches);
@@ -2291,8 +2303,15 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
+    // TODO: Should this just be ignored?
+    assert(split_equal);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
src/llama-kv-cache.h

Lines changed: 8 additions & 4 deletions
@@ -34,7 +34,8 @@ struct llama_kv_cache : public llama_memory_i {
             const llama_batch & batch,
             uint32_t n_ubatch,
             bool embd_pooled,
-            bool logits_all) = 0;
+            bool logits_all,
+            bool split_equal = false) = 0;
 
     // process any pending defrag/shift/etc. operations
     // optionally call once before processing a new batch
@@ -112,7 +113,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
             const llama_batch & batch,
             uint32_t n_ubatch,
             bool embd_pooled,
-            bool logits_all) override;
+            bool logits_all,
+            bool split_equal = false) override;
 
     bool update(llama_context & lctx) override;
 
@@ -289,7 +291,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
             const llama_batch & batch,
             uint32_t n_ubatch,
             bool embd_pooled,
-            bool logits_all) override;
+            bool logits_all,
+            bool split_equal = false) override;
 
     bool update(llama_context & lctx) override;
 
@@ -360,7 +363,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
-            bool logits_all) override;
+            bool logits_all,
+            bool split_equal = true) override;
 
    bool update(llama_context & lctx) override;
 
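For existing call sites, the defaulted parameter keeps today's four-argument calls compiling unchanged: the unified and SWA caches keep using split_simple unless asked otherwise, while the recurrent cache's default of true matches the assert added in llama-kv-cache.cpp. An assumed call site (illustrative only, not part of this diff) would look like:

    // kv_self assumed to point at a llama_kv_cache_unified
    auto state    = kv_self->init(batch, n_ubatch, embd_pooled, logits_all);                       // simple split via the default
    auto state_eq = kv_self->init(batch, n_ubatch, embd_pooled, logits_all, /*split_equal=*/true); // equal split, as a hybrid parent would request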