@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
2727 const llama_vocab & vocab,
2828 const llama_memory_i * memory,
2929 uint32_t n_embd,
30+ uint32_t n_seq_max,
3031 bool output_all) {
3132 clear ();
3233
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
4041 // validate input batch
4142 //
4243
44+ if (n_seq_max > LLAMA_MAX_SEQ) {
45+ LLAMA_LOG_ERROR (" %s: n_seq_max = %d > %d\n " , __func__, n_seq_max, LLAMA_MAX_SEQ);
46+ return false ;
47+ }
48+
4349 if (batch.token ) {
4450 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
4551 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= vocab.n_tokens ()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
5258 if (batch.seq_id ) {
5359 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
5460 for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
55- if (batch.seq_id && (batch.seq_id [i][s] < 0 || batch.seq_id [i][s] >= LLAMA_MAX_SEQ )) {
56- LLAMA_LOG_ERROR (" %s: invalid seq_id[%d][%d] = %d > %d\n " , __func__, i, s, batch.seq_id [i][s], LLAMA_MAX_SEQ );
61+ if (batch.seq_id && (batch.seq_id [i][s] < 0 || batch.seq_id [i][s] >= (llama_seq_id) n_seq_max )) {
62+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%d][%d] = %d > %d\n " , __func__, i, s, batch.seq_id [i][s], (llama_seq_id) n_seq_max );
5763 return false ;
5864 }
5965 }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
8692
8793 // initialize the starting position for each sequence based on the positions in the memory
8894 llama_pos p0[LLAMA_MAX_SEQ];
89- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
95+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
9096 if (!memory) {
9197 // if no memory -> start from 0
9298 p0[s] = 0 ;
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
143149 // compute stats
144150 //
145151
146- this ->n_embd = n_embd;
152+ this ->n_embd = n_embd;
153+ this ->n_seq_max = n_seq_max;
147154
148155 // count the outputs in this batch
149156 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
189196 seq_set_map[cur].push_back (i);
190197 }
191198
192- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
199+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
193200 if (seq_set_unq.test (s)) {
194201 seq_idx[s] = seq_id_unq.size ();
195202 seq_id_unq.push_back (s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
241248 // consistency checks
242249 //
243250
244- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
251+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
245252 if (seq_pos[s].empty ()) {
246253 continue ;
247254 }
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
284291 }
285292
286293 if (memory) {
287- for (int32_t s0 = 0 ; s0 < LLAMA_MAX_SEQ ; ++s0) {
288- for (int32_t s1 = 0 ; s1 < LLAMA_MAX_SEQ ; ++s1) {
294+ for (uint32_t s0 = 0 ; s0 < n_seq_max ; ++s0) {
295+ for (uint32_t s1 = 0 ; s1 < n_seq_max ; ++s1) {
289296 if (seq_cpl[s0][s1]) {
290297 if (memory->seq_pos_min (s0) != memory->seq_pos_min (s1) ||
291298 memory->seq_pos_max (s0) != memory->seq_pos_max (s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
316323 //
317324 {
318325 seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
319- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
326+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
320327 cur_seq_set[s].set ();
321328 }
322329
323330 llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
324- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
331+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
325332 cur_seq_pos[s] = -1 ;
326333 }
327334
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
692699 }
693700 }
694701
695- for (int32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
702+ for (uint32_t s = 0 ; s < n_seq_max ; ++s) {
696703 if (seq_set_unq.test (s)) {
697704 ubatch.seq_idx [s] = ubatch.seq_id_unq .size ();
698705 ubatch.seq_id_unq .push_back (s);
0 commit comments