Commit 0123ff3

memory : use sequential equal splits for recurrent modules (ggml-org#16442)
1 parent: 0a319bb

File tree

2 files changed: +6, -2 lines

    src/llama-memory-hybrid.cpp
    src/llama-memory-recurrent.cpp

src/llama-memory-hybrid.cpp

Lines changed: 3 additions & 1 deletion

@@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                // for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {

src/llama-memory-recurrent.cpp

Lines changed: 3 additions & 1 deletion

@@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
                 // if all tokens are output, split by sequence
                 ubatch = balloc.split_seq(n_ubatch);
             } else {
-                ubatch = balloc.split_equal(n_ubatch, false);
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                // for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
             }
 
             if (ubatch.n_tokens == 0) {
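Both hunks make the same one-line change: the second argument of balloc.split_equal() flips from false to true, so the hybrid and recurrent memory modules now always request a sequential equal split. The toy program below is a conceptual sketch of the difference, not the llama.cpp implementation: toy_ubatch, split_equal_toy, and the per-sequence pending-token map are invented for illustration. It assumes that an "equal split" takes the same number of tokens from every participating sequence, and that the sequential variant only groups sequences whose ids form a consecutive run, which is the shape a recurrent-state cache with one state cell per sequence can consume.

// Conceptual sketch only -- not the llama.cpp implementation.
#include <algorithm>
#include <cstdio>
#include <map>
#include <vector>

struct toy_ubatch {
    std::vector<int> seq_ids; // sequences contributing tokens to this ubatch
    int per_seq = 0;          // equal number of tokens taken from each
};

// split the pending per-sequence token counts into equal-share ubatches of at
// most n_ubatch tokens; if sequential, only merge sequences whose ids form a
// consecutive run (toy assumption: n_ubatch >= number of pending sequences)
static std::vector<toy_ubatch> split_equal_toy(std::map<int, int> pending, int n_ubatch, bool sequential) {
    std::vector<toy_ubatch> out;
    for (;;) {
        toy_ubatch ub;
        int prev_id = -1;
        for (const auto & [id, n] : pending) {
            if (n == 0) {
                continue;
            }
            // sequential mode: after the first pick, only accept the next consecutive id
            if (sequential && !ub.seq_ids.empty() && id != prev_id + 1) {
                break;
            }
            ub.seq_ids.push_back(id);
            prev_id = id;
        }
        if (ub.seq_ids.empty()) {
            break; // all sequences drained
        }
        // equal share per sequence, bounded by the ubatch budget and the
        // shortest participating sequence
        ub.per_seq = n_ubatch / (int) ub.seq_ids.size();
        for (int id : ub.seq_ids) {
            ub.per_seq = std::min(ub.per_seq, pending[id]);
        }
        for (int id : ub.seq_ids) {
            pending[id] -= ub.per_seq;
        }
        out.push_back(ub);
    }
    return out;
}

int main() {
    // sequences 0 and 2 still have 4 tokens each; sequence 1 is drained
    const std::map<int, int> pending = { { 0, 4 }, { 1, 0 }, { 2, 4 } };

    for (bool sequential : { false, true }) {
        std::printf("sequential=%d:\n", (int) sequential);
        for (const auto & ub : split_equal_toy(pending, /*n_ubatch=*/8, sequential)) {
            std::printf("  ubatch: %d token(s) each from seqs [", ub.per_seq);
            for (size_t i = 0; i < ub.seq_ids.size(); ++i) {
                std::printf("%s%d", i ? " " : "", ub.seq_ids[i]);
            }
            std::printf("]\n");
        }
    }
    return 0;
}

Under these assumptions, the non-sequential mode merges sequences 0 and 2 into a single ubatch even though sequence 1 sits between them, while the sequential mode emits one ubatch per contiguous run (here, one for sequence 0 and one for sequence 2). The TODO added in the diff notes that the merged, non-sequential form would still be usable with a unified KV cache; the sequential form is kept for simplicity.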
