33#include "llama-impl.h"
44#include "llama-cparams.h"
55#include "llama-vocab.h"
6+ #include "llama-memory.h"
67
78#include <cassert>
89#include <cstring>
@@ -295,7 +296,10 @@ llama_batch_allocr::llama_batch_allocr() {
295296 }
296297}
297298
298- bool llama_batch_allocr::init (const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
299+ bool llama_batch_allocr::init (
300+ const llama_batch & batch_inp,
301+ const llama_vocab & vocab,
302+ const llama_memory_i * memory) {
299303 clear ();
300304
301305 batch = batch_inp;
@@ -306,14 +310,6 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
306310 // validate input batch
307311 //
308312
309- // TODO: remove
310- if (!batch.pos ) {
311- if (batch.seq_id ) {
312- LLAMA_LOG_ERROR (" %s: pos == NULL, but seq_id != NULL\n " , __func__);
313- return false ;
314- }
315- }
316-
317313 if (batch.token ) {
318314 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
319315 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= vocab.n_tokens ()) {
@@ -338,15 +334,6 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
338334 // auto-generate missing fields
339335 //
340336
341- if (!batch.pos ) {
342- assert (p0 >= 0 );
343- pos.resize (batch.n_tokens );
344- for (int32_t i = 0 ; i < batch.n_tokens ; i++) {
345- pos[i] = p0 + i;
346- }
347- batch.pos = pos.data ();
348- }
349-
350337 if (!batch.n_seq_id ) {
351338 n_seq_id.resize (batch.n_tokens );
352339 for (int32_t i = 0 ; i < batch.n_tokens ; i++) {
@@ -364,6 +351,27 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
364351 batch.seq_id = seq_id.data ();
365352 }
366353
354+ if (!batch.pos ) {
355+ pos.resize (batch.n_tokens );
356+
357+ // initialize the starting position for each sequence based on the positions in the memory
358+ llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
359+ for (int32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
360+ if (!memory) {
361+ p0[s] = 0 ;
362+ } else {
363+ p0[s] = memory->seq_pos_max (s) + 1 ;
364+ }
365+ }
366+
367+ for (int32_t i = 0 ; i < batch.n_tokens ; i++) {
368+ const llama_seq_id seq_id = batch.seq_id [i][0 ];
369+ pos[i] = p0[seq_id] + i;
370+ }
371+
372+ batch.pos = pos.data ();
373+ }
374+
367375 if (!batch.logits ) {
368376 // by default return the output only for the last token
369377 output.resize (batch.n_tokens );
@@ -379,24 +387,54 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
379387 n_outputs += batch.logits [i] != 0 ;
380388 }
381389
390+ // determine coupled sequences
391+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
382392 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
383393 for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
384394 seq_pos[batch.seq_id [i][s]].insert (batch.pos [i]);
385395
386396 if (s > 0 ) {
387- seq_cpl[batch.seq_id [i][0 ]][batch.seq_id [i][s]] = true ;
397+ const llama_seq_id s0 = batch.seq_id [i][0 ];
398+ const llama_seq_id s1 = batch.seq_id [i][s];
399+
400+ seq_cpl[s1][s0] = true ;
388401 }
389402 }
390403 }
391404
392- // TODO:
393- // - verify that coupled sequences have same "position contexts"
394- // - verify that input sequences are "contiguous" (no position gaps)
395- // - verify that input sequences begin from the last poition currently in the context
405+ for (int32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
406+ if (seq_pos[s].empty ()) {
407+ continue ;
408+ }
409+
410+ if (memory && seq_pos_min (s) != memory->seq_pos_max (s) + 1 ) {
411+ LLAMA_LOG_ERROR (" %s: sequence %d does not start from the last position stored in the memory\n " , __func__, s);
412+ return false ;
413+ }
414+
415+ if (seq_pos_max (s) - seq_pos_min (s) + 1 > (int ) seq_pos[s].size ()) {
416+ LLAMA_LOG_ERROR (" %s: sequence %d is not contiguous\n " , __func__, s);
417+ return false ;
418+ }
419+ }
420+
421+ if (memory) {
422+ for (int32_t s0 = 0 ; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
423+ for (int32_t s1 = 0 ; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
424+ if (seq_cpl[s0][s1]) {
425+ if (memory->seq_pos_min (s0) != memory->seq_pos_min (s1) ||
426+ memory->seq_pos_max (s0) != memory->seq_pos_max (s1)) {
427+ LLAMA_LOG_ERROR (" %s: sequence %d is coupled to %d in the input batch, but they have diverged\n " , __func__, s0, s1);
428+ return false ;
429+ }
430+ }
431+ }
432+ }
433+ }
396434
397435 if (debug > 0 ) {
398- LLAMA_LOG_DEBUG (" %s: input batch info (p0 = %d) :\n " , __func__, p0 );
399- LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, batch.n_tokens );
436+ LLAMA_LOG_DEBUG (" %s: input batch info:\n " , __func__);
437+ LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, batch.n_tokens );
400438 LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) batch.token );
401439 LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) batch.embd );
402440 LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) batch.pos );
@@ -439,14 +477,21 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
439477 }
440478 LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
441479
442- LLAMA_LOG_DEBUG (" %s: seq_pos = [\n " , __func__);
443- for (int s = 0 ; s < (int ) seq_pos.size (); ++s) {
444- const auto & cur = seq_pos[s];
445- if (cur.empty ()) {
480+ LLAMA_LOG_DEBUG (" %s: seq = [\n " , __func__);
481+ for (int s0 = 0 ; s0 < (int ) seq_pos.size (); ++s0) {
482+ if (seq_pos[s0].empty ()) {
446483 continue ;
447484 }
448485
449- LLAMA_LOG_DEBUG (" %s: %4d: [%4d, %4d]\n " , __func__, s, seq_pos_min (s), seq_pos_max (s));
486+ std::stringstream ss;
487+ for (int s1 = 0 ; s1 < (int ) seq_cpl[s0].size (); ++s1) {
488+ if (seq_cpl[s0][s1]) {
489+ ss << s1 << " " ;
490+ }
491+ }
492+
493+ LLAMA_LOG_DEBUG (" %s: %4d: pos = [%4d, %4d], cpl = %s\n " ,
494+ __func__, s0, seq_pos_min (s0), seq_pos_max (s0), ss.str ().empty () ? " -" : ss.str ().c_str ());
450495 }
451496 LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
452497 }
@@ -485,7 +530,7 @@ void llama_batch_allocr::clear() {
485530 }
486531
487532 for (auto & cur : seq_cpl) {
488- cur.clear ( );
533+ std::fill ( cur.begin (), cur. end (), false );
489534 }
490535}
491536
0 commit comments