@@ -402,36 +402,6 @@ bool llama_batch_allocr::init(
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
-        }
-
-        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-            return false;
-        }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d is not contiguous\n", __func__, s);
-            return false;
-        }
-    }
-
-    if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
-                if (seq_cpl[s0][s1]) {
-                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
-                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
-                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have diverged\n", __func__, s0, s1);
-                        return false;
-                    }
-                }
-            }
-        }
-    }
-
     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
@@ -497,6 +467,36 @@ bool llama_batch_allocr::init(
         }
     }
 
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d is not contiguous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+
     return true;
 }
 
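With this relocation, the `debug > 0` dump of the input batch runs before the validation, so a batch that fails these checks still gets logged. As a rough illustration of what the moved block enforces per sequence, here is a minimal standalone sketch; `validate_seq`, `pos`, and `mem_pos_max` are hypothetical names introduced for this example, with `pos` standing in for `seq_pos[s]` (an ordered set of positions, as the `seq_pos_min`/`seq_pos_max` usage above suggests) and `mem_pos_max` standing in for `memory->seq_pos_max(s)`:

```cpp
#include <cstdio>
#include <set>

// Sketch of the relocated per-sequence checks, not the actual llama.cpp API.
// `pos` plays the role of seq_pos[s]; `mem_pos_max` stands in for
// memory->seq_pos_max(s) (pass -1 when the sequence is not in memory yet).
static bool validate_seq(const std::set<int> & pos, int mem_pos_max) {
    if (pos.empty()) {
        return true; // nothing to validate for this sequence
    }

    const int p_min = *pos.begin();  // seq_pos_min(s)
    const int p_max = *pos.rbegin(); // seq_pos_max(s)

    // the batch must continue exactly where the memory left off
    if (mem_pos_max >= 0 && p_min != mem_pos_max + 1) {
        fprintf(stderr, "sequence does not start from the last position stored in the memory\n");
        return false;
    }

    // [p_min, p_max] must cover exactly as many positions as the set holds,
    // otherwise there is a hole somewhere inside the batch
    if (p_max - p_min + 1 > (int) pos.size()) {
        fprintf(stderr, "sequence is not contiguous\n");
        return false;
    }

    return true;
}

int main() {
    printf("%d\n", validate_seq({5, 6, 7}, 4)); // 1: continues memory, contiguous
    printf("%d\n", validate_seq({6, 7},    4)); // 0: gap after memory position 4
    printf("%d\n", validate_seq({5, 7},    4)); // 0: hole inside the batch
}
```

The coupled-sequence check (`seq_cpl`) is omitted from the sketch; it only compares the min/max positions already stored in memory for the two coupled sequences and fails if they have diverged.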