@@ -276,15 +276,16 @@ void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool
 
 llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
     batch = new llama_batch_ext{
-        /* n_tokens   =*/ in_batch.n_tokens,
-        /* max_tokens =*/ in_batch.n_tokens,
-        /* is_view    =*/ false,
-        /* tokens     =*/ in_batch.token,
-        /* embd       =*/ in_batch.embd,
-        /* pos        =*/ in_batch.pos,
-        /* n_seq_id   =*/ in_batch.n_seq_id,
-        /* seq_id     =*/ in_batch.seq_id,
-        /* logits     =*/ in_batch.logits,
+        /* n_tokens        =*/ in_batch.n_tokens,
+        /* max_tokens      =*/ in_batch.n_tokens,
+        /* n_pos_per_token =*/ 1,
+        /* is_view         =*/ false,
+        /* tokens          =*/ in_batch.token,
+        /* embd            =*/ in_batch.embd,
+        /* pos             =*/ in_batch.pos,
+        /* n_seq_id        =*/ in_batch.n_seq_id,
+        /* seq_id          =*/ in_batch.seq_id,
+        /* logits          =*/ in_batch.logits,
     };
     GGML_ASSERT(batch->n_tokens > 0);
     if (!in_batch.pos) {
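
For orientation: the legacy `llama_batch` carries exactly one position per token, which is why the compatibility allocator above hard-codes `n_pos_per_token` to 1. The initializer order implies a batch struct roughly like the sketch below; field names follow the `/* ... =*/` comments and types are inferred from how the fields are used later in this file, so treat it as a reconstruction, not the verbatim header.

```cpp
// Sketch of llama_batch_ext as implied by the initializers above (field
// order from the /* ... =*/ comments; types inferred from usage below).
struct llama_batch_ext {
    int32_t n_tokens;        // tokens currently in the batch
    int32_t max_tokens;      // allocated capacity
    int32_t n_pos_per_token; // 1 for plain text; >1 for multi-dimensional
                             // positions (M-RoPE models such as Qwen2-VL)
    bool    is_view;         // views alias another batch's buffers

    llama_token   * token;    // [max_tokens]
    float         * embd;     // [max_tokens * n_embd] or nullptr
    llama_pos     * pos;      // [max_tokens * n_pos_per_token], token-major
    int32_t       * n_seq_id; // [max_tokens]
    llama_seq_id ** seq_id;   // [max_tokens][n_seq_id[i]]
    int8_t        * logits;   // [max_tokens]; non-zero = return output
};
```
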
@@ -338,17 +339,18 @@ struct llama_batch llama_batch_get_one(
     };
 }
 
-static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
+static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max, int32_t n_pos_per_token) {
     llama_batch_ext * batch = new llama_batch_ext{
-        /* n_tokens   =*/ 0,
-        /* max_tokens =*/ n_tokens_alloc,
-        /* is_view    =*/ false,
-        /* tokens     =*/ nullptr,
-        /* embd       =*/ nullptr,
-        /* pos        =*/ nullptr,
-        /* n_seq_id   =*/ nullptr,
-        /* seq_id     =*/ nullptr,
-        /* logits     =*/ nullptr,
+        /* n_tokens        =*/ 0,
+        /* max_tokens      =*/ n_tokens_alloc,
+        /* n_pos_per_token =*/ n_pos_per_token,
+        /* is_view         =*/ false,
+        /* tokens          =*/ nullptr,
+        /* embd            =*/ nullptr,
+        /* pos             =*/ nullptr,
+        /* n_seq_id        =*/ nullptr,
+        /* seq_id          =*/ nullptr,
+        /* logits          =*/ nullptr,
     };
 
     if (n_embd) {
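
The buffer allocation further down in `llama_batch_ext_init_impl` falls outside this hunk, but with the new field the position buffer presumably has to grow by the same factor. A hypothetical sketch of that continuation (the actual lines are not shown here):

```cpp
// Hypothetical continuation (not part of this hunk): the pos buffer now
// needs n_pos_per_token entries per token slot instead of one.
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc * n_pos_per_token);
```
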
@@ -371,7 +373,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
 }
 
 struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
-    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx), n_pos_per_token);
 }
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
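
With this change the caller no longer chooses the position width; it follows the model. A minimal usage sketch, assuming a loaded `ctx` and the matching `llama_batch_ext_free` from the same API family:

```cpp
// Usage sketch: the batch is sized from the context, and its position
// width follows the model (1 for text-only models, more for M-RoPE).
llama_batch_ext * batch = llama_batch_ext_init(ctx);
// ... fill the batch and decode ...
llama_batch_ext_free(batch); // assumed counterpart in the same API
```
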
@@ -381,10 +384,10 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
         size_t n_embd,
         const llama_pos * pos,
         llama_seq_id seq_id) {
-    auto model = llama_get_model(ctx);
-    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1, n_pos_per_token);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
-    memcpy(batch->pos, pos, n_tokens * llama_n_pos_per_token(model) * sizeof(llama_pos));
+    memcpy(batch->pos, pos, n_tokens * n_pos_per_token * sizeof(llama_pos));
     for (size_t i = 0; i < n_tokens; i++) {
         batch->n_seq_id[i] = 1;
         batch->seq_id[i][0] = seq_id;
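
The `pos` argument is now expected to hold `n_tokens * n_pos_per_token` values, token-major: all positions of token `i` precede those of token `i + 1`. A caller-side sketch with placeholder values (a real M-RoPE caller would fill genuine per-dimension positions):

```cpp
#include <vector>

// Builds a token-major pos buffer to pass as the `pos` argument of
// llama_batch_ext_init_from_embd. Placeholder values: every dimension
// simply gets the token index.
static std::vector<llama_pos> make_pos(struct llama_context * ctx, size_t n_tokens) {
    const int32_t npt = llama_n_pos_per_token(llama_get_model(ctx));
    std::vector<llama_pos> pos(n_tokens * npt);
    for (size_t i = 0; i < n_tokens; i++) {
        for (int32_t d = 0; d < npt; d++) {
            pos[i * npt + d] = (llama_pos) i; // token-major indexing
        }
    }
    return pos;
}
```
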
@@ -411,12 +414,16 @@ int32_t llama_batch_ext_add_text(
     }
     const int32_t output_id = batch->n_tokens;
     batch->token[output_id] = token;
-    batch->pos  [output_id] = pos;
+    batch->n_seq_id[output_id] = n_seq_ids;
+    batch->logits[output_id] = output;
+    for (int32_t i = 0; i < batch->n_pos_per_token; i++) {
+        // TODO: this is only used by qwen2vl for now, and text tokens only have 3 pos, the last is set to 0; we should improve this code in the future
+        batch->pos[output_id * batch->n_pos_per_token + i] = i < 3 ? pos : 0;
+    }
     batch->n_seq_id[output_id] = n_seq_ids;
     for (size_t j = 0; j < n_seq_ids; j++) {
         batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-    batch->logits[output_id] = output;
     batch->n_tokens++;
     return output_id;
 }
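
Note the multi-position behavior here: the single scalar `pos` fans out to the first three position dimensions, and any remaining dimension is zeroed, per the qwen2vl TODO above. A usage sketch for plain text, assuming the parameter order `(batch, token, pos, seq_ids, n_seq_ids, output)` inferred from the body:

```cpp
// Usage sketch: append a prompt to one sequence (id 0), requesting output
// only for the last token. Parameter order inferred from the body above.
llama_seq_id seq_id = 0;
for (int32_t i = 0; i < n_prompt; i++) {
    const bool want_output = (i == n_prompt - 1);
    llama_batch_ext_add_text(batch, prompt_tokens[i], /*pos=*/ i, &seq_id, 1, want_output);
}
```
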
@@ -461,15 +468,16 @@ struct llama_batch_ext * llama_batch_ext_get_view(
         return nullptr; // not yet supported
     }
     llama_batch_ext * batch_view = new llama_batch_ext{
-        /* n_tokens   =*/ n_tokens,
-        /* max_tokens =*/ n_tokens,
-        /* is_view    =*/ true,
-        /* tokens     =*/ batch->token + offset,
-        /* embd       =*/ nullptr,
-        /* pos        =*/ batch->pos + offset,
-        /* n_seq_id   =*/ batch->n_seq_id + offset,
-        /* seq_id     =*/ batch->seq_id + offset,
-        /* logits     =*/ batch->logits + offset,
+        /* n_tokens        =*/ n_tokens,
+        /* max_tokens      =*/ n_tokens,
+        /* n_pos_per_token =*/ batch->n_pos_per_token,
+        /* is_view         =*/ true,
+        /* tokens          =*/ batch->token + offset,
+        /* embd            =*/ nullptr,
+        /* pos             =*/ batch->pos + offset * batch->n_pos_per_token,
+        /* n_seq_id        =*/ batch->n_seq_id + offset,
+        /* seq_id          =*/ batch->seq_id + offset,
+        /* logits          =*/ batch->logits + offset,
     };
     return batch_view;
 }
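
Because `pos` is a flattened `[max_tokens * n_pos_per_token]` array, the view's `pos` pointer must advance by `offset * n_pos_per_token`, while the per-token arrays advance by `offset` alone. A sketch of the aliasing this produces (argument order of the call assumed from the surrounding code):

```cpp
// Sketch: a view over tokens [offset, offset + n_tokens) of `batch`.
llama_batch_ext * view = llama_batch_ext_get_view(batch, offset, n_tokens);
// Per-token arrays shift by `offset`:
//   view->token[i] == batch->token[offset + i]
// The flattened pos array shifts by offset * n_pos_per_token:
//   view->pos[i * view->n_pos_per_token + d]
//       == batch->pos[(offset + i) * batch->n_pos_per_token + d]
```
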