@@ -352,14 +352,19 @@ llama_memory_decode_state_ptr llama_kv_cache_unified::init(
         const llama_batch & batch,
         uint32_t n_ubatch,
         bool embd_pooled,
-        bool logits_all) {
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
     while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
+        if (split_equal) {
+            ubatches.push_back(sbatch.split_equal(n_ubatch));
+        } else {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
     }
 
     auto heads = prepare(ubatches);
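The only behavioral change in each of these `init()` overrides is how the `llama_sbatch` is carved into ubatches: `split_simple()` fills each ubatch with up to `n_ubatch` tokens in batch order, while `split_equal()` builds ubatches in which every participating sequence contributes the same number of tokens. A rough, self-contained sketch of that distinction (not the actual `llama_sbatch` logic; the per-sequence bookkeeping below is invented purely for illustration):

```cpp
// Toy illustration of the two splitting strategies selected by split_equal.
// This is NOT the llama_sbatch implementation; a batch is reduced here to
// per-sequence token counts just to show how ubatches are formed.
#include <algorithm>
#include <cstdio>
#include <vector>

// remaining[i] = tokens still to be scheduled for sequence i
using remaining_t = std::vector<int>;

// "simple" split: fill the ubatch with up to n_ubatch tokens in sequence order
static std::vector<int> take_simple(remaining_t & rem, int n_ubatch) {
    std::vector<int> ub(rem.size(), 0);
    int budget = n_ubatch;
    for (size_t i = 0; i < rem.size() && budget > 0; ++i) {
        const int take = std::min(rem[i], budget);
        ub[i]   = take;
        rem[i] -= take;
        budget -= take;
    }
    return ub;
}

// "equal" split: every sequence that still has tokens contributes the same
// number of tokens to the ubatch, bounded by n_ubatch in total
static std::vector<int> take_equal(remaining_t & rem, int n_ubatch) {
    std::vector<int> ub(rem.size(), 0);
    int n_active = 0;
    int n_min    = 0;
    for (int n : rem) {
        if (n > 0) {
            n_active += 1;
            n_min = (n_min == 0) ? n : std::min(n_min, n);
        }
    }
    if (n_active == 0) {
        return ub;
    }
    const int take = std::min(n_min, std::max(1, n_ubatch / n_active));
    for (size_t i = 0; i < rem.size(); ++i) {
        if (rem[i] > 0) {
            ub[i]   = take;
            rem[i] -= take;
        }
    }
    return ub;
}

int main() {
    remaining_t rem = {5, 3, 8};           // three sequences of different lengths
    auto ub_simple = take_simple(rem, 8);  // -> {5, 3, 0}: order-based fill
    rem = {5, 3, 8};
    auto ub_equal  = take_equal(rem, 8);   // -> {2, 2, 2}: same count per sequence
    printf("simple: %d %d %d\n", ub_simple[0], ub_simple[1], ub_simple[2]);
    printf("equal : %d %d %d\n", ub_equal[0],  ub_equal[1],  ub_equal[2]);
    return 0;
}
```

Equal-length ubatches are presumably what the recurrent cache in the last hunk relies on, which is why it asserts the flag rather than branching on it.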
@@ -1821,17 +1826,24 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
 
     while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
-
-        ubatches.push_back(ubatch);
+        if (split_equal) {
+            ubatches.push_back(sbatch.split_equal(n_ubatch));
+        } else {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
     }
 
     auto heads_base = kv_base->prepare(ubatches);
@@ -2291,8 +2303,15 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
     GGML_UNUSED(embd_pooled);
+    // TODO: Should this just be ignored?
+    assert(split_equal);
 
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
 
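Unlike the two unified caches, the recurrent cache does not fall back to `split_simple()`; it asserts that callers request equal splits (the TODO above questions whether the flag should simply be ignored here instead). A hypothetical caller-side sketch, assuming the flag is derived from the model type via the existing `llama_model_is_recurrent()` API; the surrounding variable names are placeholders, not code from this change:

```cpp
// Hypothetical caller sketch (names are placeholders, not part of this change):
// recurrent memories need equal-length ubatches, so derive the flag from the model.
const bool split_equal = llama_model_is_recurrent(&model);

llama_memory_decode_state_ptr dstate =
    memory->init(batch, cparams.n_ubatch, embd_pooled, logits_all, split_equal);
if (!dstate) {
    // splitting or allocation failed; give up on this decode call
    return -1;
}
```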