@@ -88,6 +88,7 @@ bool llama_batch_allocr::init(
         llama_pos p0[LLAMA_MAX_SEQ];
         for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (!memory) {
+                // if no memory -> start from 0
                 p0[s] = 0;
             } else {
                 p0[s] = memory->seq_pos_max(s) + 1;
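Note: the hunk above seeds each sequence's next implicit position: 0 when no memory is attached, otherwise one past the last position stored for that sequence. A minimal standalone sketch of the same idea, where `seq_pos_max` is a hypothetical callback standing in for `memory->seq_pos_max(s)`:

#include <cstdint>
#include <functional>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// seed the next implicit position of each sequence:
// 0 without memory, otherwise one past the last stored position
static std::vector<llama_pos> seed_positions(
        int32_t n_seq, const std::function<llama_pos(llama_seq_id)> & seq_pos_max) {
    std::vector<llama_pos> p0(n_seq);
    for (llama_seq_id s = 0; s < n_seq; ++s) {
        p0[s] = seq_pos_max ? seq_pos_max(s) + 1 : 0;
    }
    return p0;
}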
@@ -99,8 +100,11 @@ bool llama_batch_allocr::init(
 
             pos[i] = p0[seq_id];
 
+            // update the starting position for all sequences that are assigned to this token
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                p0[batch.seq_id[i][s]] = pos[i] + 1;
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                p0[seq_id] = pos[i] + 1;
             }
         }
 
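Note: this loop is what lets callers omit explicit positions: each token takes the current start of its first sequence, then every sequence the token belongs to advances past it. A simplified sketch of the same logic, assuming flat vectors instead of the real `llama_batch` layout:

#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// assign implicit positions: token i starts at p0 of its first sequence,
// then all sequences it belongs to continue from pos[i] + 1
static std::vector<llama_pos> assign_positions(
        const std::vector<std::vector<llama_seq_id>> & seq_ids, // seq ids per token
        std::vector<llama_pos> p0) {                            // next position per sequence
    std::vector<llama_pos> pos(seq_ids.size());
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        pos[i] = p0[seq_ids[i][0]]; // each token is assumed to have at least one seq id
        for (const llama_seq_id seq_id : seq_ids[i]) {
            p0[seq_id] = pos[i] + 1;
        }
    }
    return pos;
}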
@@ -141,6 +145,7 @@ bool llama_batch_allocr::init(
 
     this->n_embd = n_embd;
 
+    // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
@@ -159,22 +164,23 @@ bool llama_batch_allocr::init(
                 // mark that sequence s1 is coupled to s0
                 seq_cpl[s1][s0] = true;
 
-                // note: the other way around is not necessary for now
+                // note: tracking the other way around is not necessary for now
                 // seq_cpl[s0][s1] = true;
             }
         }
     }
 
+    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
     {
         seq_set_t seq_set_unq;
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             seq_set_t cur;
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                const llama_seq_id s0 = batch.seq_id[i][s];
+                const llama_seq_id seq_id = batch.seq_id[i][s];
 
-                cur.set(s0);
-                seq_set_unq.set(s0);
+                cur.set(seq_id);
+                seq_set_unq.set(seq_id);
             }
 
             seq_set.push_back(cur);
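Note: a sketch of the precomputation documented by the new comment, assuming `seq_set_t` is a `std::bitset<LLAMA_MAX_SEQ>` keyed by sequence id (which matches how the allocator uses it here):

#include <bitset>
#include <cstdint>
#include <vector>

constexpr int32_t LLAMA_MAX_SEQ = 64; // illustrative value

using llama_seq_id = int32_t;
using seq_set_t    = std::bitset<LLAMA_MAX_SEQ>;

// build one bitset per token and the union of all sets in the batch
static void build_seq_sets(
        const std::vector<std::vector<llama_seq_id>> & seq_ids,
        std::vector<seq_set_t> & seq_set, seq_set_t & seq_set_unq) {
    for (const auto & ids : seq_ids) {
        seq_set_t cur;
        for (const llama_seq_id seq_id : ids) {
            cur.set(seq_id);
            seq_set_unq.set(seq_id);
        }
        seq_set.push_back(cur);
    }
}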
@@ -263,6 +269,15 @@ bool llama_batch_allocr::init(
         }
     }
 
+    // disallow disjoint sequence sets:
+    //
+    // invalid:          x
+    //            i: 0 1 2 ...
+    // ---------------------------------------
+    // seq_id[i][0]: 0 0 1
+    // seq_id[i][1]: 1 1 2
+    // seq_id[i][2]: 2
+    //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
         for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
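Note: the new comment documents the invariant enforced by the block that follows: once a sequence id appears together with a set of sequences, later tokens may narrow that set but never switch to a disjoint one. A hedged sketch of one way to enforce this, keeping a running intersection per sequence id and rejecting the batch when it becomes empty (the real check may differ in detail):

#include <bitset>
#include <cstdint>
#include <vector>

constexpr int32_t LLAMA_MAX_SEQ = 64; // illustrative value

using llama_seq_id = int32_t;
using seq_set_t    = std::bitset<LLAMA_MAX_SEQ>;

// reject batches where a sequence id is used with disjoint sequence sets
static bool validate_seq_sets(
        const std::vector<std::vector<llama_seq_id>> & seq_ids, // seq ids per token
        const std::vector<seq_set_t> & seq_set) {               // bitset per token
    seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        cur_seq_set[s].set(); // start from the full set
    }

    for (size_t i = 0; i < seq_ids.size(); ++i) {
        for (const llama_seq_id seq_id : seq_ids[i]) {
            cur_seq_set[seq_id] &= seq_set[i]; // running intersection

            if (cur_seq_set[seq_id].none()) {
                return false; // disjoint usage, as in the "invalid" example above
            }
        }
    }

    return true;
}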
@@ -368,11 +383,13 @@ void llama_batch_allocr::split_reset() {
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    // find the first unused token
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
+    // we are done
     if (cur_idx >= used.size()) {
         return {};
     }
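Note: a simplified sketch of the selection logic these comments describe: scan to the first unused token, then claim a contiguous run of up to `n_ubatch` unused tokens; an empty result signals that the batch is drained. The real method additionally materializes the `llama_ubatch` view from the collected indices:

#include <cstdint>
#include <vector>

// claim a contiguous run of up to n_ubatch unused tokens,
// starting at the first unused one; empty result == batch drained
static std::vector<int32_t> split_simple_idxs(std::vector<bool> & used, uint32_t n_ubatch) {
    uint32_t cur_idx = 0;

    // find the first unused token
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    std::vector<int32_t> idxs;
    while (cur_idx < used.size() && !used[cur_idx] && idxs.size() < n_ubatch) {
        used[cur_idx] = true;
        idxs.push_back(cur_idx);
        ++cur_idx;
    }

    return idxs;
}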
@@ -401,7 +418,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
     std::vector<seq_set_t> cur_seq_set;
 
-    // determine the sequence sets participating in this ubatch
+    // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
             continue;
@@ -428,10 +445,12 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 
     const uint32_t n_seqs = cur_seq_set.size();
 
+    // we are done
     if (n_seqs == 0) {
         return {};
     }
 
+    // the current batch index of each sequence set
     std::vector<int32_t> cur_idx(n_seqs, 0);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
@@ -440,9 +459,13 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         }
     }
 
+    // the list of batch indices for each sequence set
+    // at the end we will concat these to get the final ubatch
     std::vector<idx_vec_t> idxs_per_seq(n_seqs);
 
     while (true) {
+        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
+        // if we haven't reached n_ubatch
         bool can_expand = true;
 
         for (uint32_t s = 0; s < n_seqs; ++s) {
@@ -458,6 +481,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 
         for (uint32_t s = 0; s < n_seqs; ++s) {
             const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
@@ -470,6 +494,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         }
     }
 
+    // concat the per-sequence-set lists
     std::vector<int32_t> idxs;
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
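Note: taken together, the commented steps in `split_equal` amount to a round-robin expansion followed by a concat. A self-contained sketch under that reading, where per-set index lists stand in for `seq_set_map` and the `llama_ubatch` construction is omitted:

#include <cstdint>
#include <vector>

// round-robin expansion: take one token per sequence set per round, while every
// set still has an unused token and one more round fits in n_ubatch
static std::vector<int32_t> split_equal_idxs(
        const std::vector<std::vector<int32_t>> & idxs_per_set, // token indices per set
        uint32_t n_ubatch) {
    const size_t n_seqs = idxs_per_set.size();

    std::vector<std::vector<int32_t>> idxs_per_seq(n_seqs);

    size_t cur = 0; // rounds completed == tokens taken per set

    while (true) {
        bool can_expand = n_seqs > 0 && (cur + 1) * n_seqs <= n_ubatch;

        for (size_t s = 0; s < n_seqs && can_expand; ++s) {
            can_expand = cur < idxs_per_set[s].size();
        }

        if (!can_expand) {
            break;
        }

        for (size_t s = 0; s < n_seqs; ++s) {
            idxs_per_seq[s].push_back(idxs_per_set[s][cur]);
        }

        ++cur;
    }

    // concat the per-sequence-set lists, keeping each set's tokens contiguous
    std::vector<int32_t> idxs;
    for (const auto & part : idxs_per_seq) {
        idxs.insert(idxs.end(), part.begin(), part.end());
    }

    return idxs;
}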
@@ -480,15 +505,19 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 }
 
 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    // find the first unused token
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
+    // we are done
     if (cur_idx >= used.size()) {
         return {};
     }
 
+    // this is the starting sequence set
+    // we allow adding tokens only if their sequence set is a subset of the current sequence set
     auto cur_seq_set = seq_set[cur_idx];
 
     std::vector<int32_t> idxs;
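Note: the added comment pins down the admission rule of `split_seq`: after the first unused token, a token may join the ubatch only if its sequence set is a subset of the starting set. With bitsets the subset test is a one-liner; a small sketch, again assuming `seq_set_t` is `std::bitset<LLAMA_MAX_SEQ>`:

#include <bitset>
#include <cstdint>

constexpr int32_t LLAMA_MAX_SEQ = 64; // illustrative value

using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

// subset test: no bit of "sub" may fall outside "super"
static inline bool is_subset(const seq_set_t & sub, const seq_set_t & super) {
    return (sub & ~super).none();
}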