@@ -366,7 +366,12 @@ bool llama_batch_allocr::init(
 
     for (int32_t i = 0; i < batch.n_tokens; i++) {
         const llama_seq_id seq_id = batch.seq_id[i][0];
-        pos[i] = p0[seq_id] + i;
+
+        pos[i] = p0[seq_id];
+
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            p0[batch.seq_id[i][s]] = pos[i] + 1;
+        }
     }
 
     batch.pos = pos.data();
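For context: the removed line assumed the whole batch belongs to a single sequence (position = start + token index), while the new loop keeps a separate counter per sequence, so a batch that interleaves several sequences still gets continuous positions within each one. Below is a minimal standalone sketch of that idea, not the actual implementation; it assumes a simplified batch where every token carries exactly one sequence id.

// sketch: per-sequence position auto-assignment (simplified; one seq id per token assumed)
#include <cstdio>
#include <vector>

int main() {
    // tokens from two interleaved sequences: 0, 1, 0, 1, 0
    const std::vector<int> seq_of_token = {0, 1, 0, 1, 0};

    std::vector<int> p0(2, 0);                 // next position to assign, per sequence
    std::vector<int> pos(seq_of_token.size()); // auto-generated positions

    for (size_t i = 0; i < seq_of_token.size(); ++i) {
        const int s = seq_of_token[i];
        pos[i] = p0[s];       // position within this token's own sequence
        p0[s]  = pos[i] + 1;  // advance only that sequence's counter
    }

    for (size_t i = 0; i < pos.size(); ++i) {
        std::printf("token %zu: seq %d -> pos %d\n", i, seq_of_token[i], pos[i]);
    }
    // prints positions 0, 0, 1, 1, 2: each sequence stays continuous
    return 0;
}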
@@ -397,7 +402,11 @@ bool llama_batch_allocr::init(
                 const llama_seq_id s0 = batch.seq_id[i][0];
                 const llama_seq_id s1 = batch.seq_id[i][s];
 
+                // mark that sequence s1 is coupled to s0
                 seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
             }
         }
     }
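When a token carries more than one sequence id, the added comments document that each extra sequence s1 is marked as coupled to the primary sequence s0, and that the reverse direction is intentionally left unset. Here is a small hypothetical sketch of filling such a one-directional coupling matrix for a single token; the n_seq_max bound and the plain bool matrix are assumptions for illustration, not the library's types.

// sketch: one-directional sequence coupling for a token tagged with sequences {0, 2, 3}
#include <vector>

int main() {
    const int n_seq_max = 4; // assumed upper bound on parallel sequences

    std::vector<std::vector<bool>> seq_cpl(n_seq_max, std::vector<bool>(n_seq_max, false));

    const std::vector<int> seq_ids = {0, 2, 3}; // this token's sequence ids

    const int s0 = seq_ids[0];      // primary sequence
    for (size_t s = 1; s < seq_ids.size(); ++s) {
        const int s1 = seq_ids[s];
        seq_cpl[s1][s0] = true;     // s1 is coupled to s0
        // the reverse direction (seq_cpl[s0][s1]) is deliberately left unset, as in the diff
    }

    return 0;
}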
@@ -467,6 +476,10 @@ bool llama_batch_allocr::init(
         }
     }
 
+    //
+    // consistency checks
+    //
+
     for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         if (seq_pos[s].empty()) {
             continue;
@@ -478,7 +491,7 @@ bool llama_batch_allocr::init(
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d is not contiguous\n", __func__, s);
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
             return false;
         }
     }
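The reworded error message matches what the condition actually checks: the positions seen for a sequence must form a gap-free run, i.e. max - min + 1 must not exceed the number of distinct positions. Below is a sketch of that check in isolation, assuming seq_pos[s] is an ordered std::set of positions, as the seq_pos_min/seq_pos_max accessors suggest; this is an illustration, not the library's code.

// sketch: continuity check for one sequence's positions (std::set assumed)
#include <cstdio>
#include <set>

// true for a gap-free run such as {5, 6, 7}; false for {5, 7}
static bool positions_continuous(const std::set<int> & seq_pos) {
    if (seq_pos.empty()) {
        return true; // nothing to check
    }
    const int pos_min = *seq_pos.begin();
    const int pos_max = *seq_pos.rbegin();
    return pos_max - pos_min + 1 <= (int) seq_pos.size();
}

int main() {
    std::printf("{5,6,7} continuous: %d\n", positions_continuous({5, 6, 7})); // 1
    std::printf("{5,7}   continuous: %d\n", positions_continuous({5, 7}));    // 0
    return 0;
}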
@@ -78,12 +78,13 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional; if provided, it will be used to check for sequence continuity
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,