@@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
309309// interface implementation
310310//
311311
312- struct llama_batch llama_batch_get_one (
312+ struct llama_batch * llama_batch_get_one (
313313 llama_token * tokens,
314314 int32_t n_tokens) {
315- return {
315+ return new llama_batch {
316316 /* n_tokens =*/ n_tokens,
317317 /* tokens =*/ tokens,
318318 /* embd =*/ nullptr ,
@@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one(
323323 };
324324}
325325
326- struct llama_batch llama_batch_init (int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327- llama_batch batch = {
326+ static struct llama_batch * llama_batch_init_impl (int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327+ llama_batch * batch = new llama_batch {
328328 /* n_tokens =*/ 0 ,
329329 /* tokens =*/ nullptr ,
330330 /* embd =*/ nullptr ,
@@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
335335 };
336336
337337 if (embd) {
338- batch. embd = (float *) malloc (sizeof (float ) * n_tokens_alloc * embd);
338+ batch-> embd = (float *) malloc (sizeof (float ) * n_tokens_alloc * embd);
339339 } else {
340- batch. token = (llama_token *) malloc (sizeof (llama_token) * n_tokens_alloc);
340+ batch-> token = (llama_token *) malloc (sizeof (llama_token) * n_tokens_alloc);
341341 }
342342
343- batch. pos = (llama_pos *) malloc (sizeof (llama_pos) * n_tokens_alloc);
344- batch. n_seq_id = (int32_t *) malloc (sizeof (int32_t ) * n_tokens_alloc);
345- batch. seq_id = (llama_seq_id **) malloc (sizeof (llama_seq_id *) * (n_tokens_alloc + 1 ));
343+ batch-> pos = (llama_pos *) malloc (sizeof (llama_pos) * n_tokens_alloc);
344+ batch-> n_seq_id = (int32_t *) malloc (sizeof (int32_t ) * n_tokens_alloc);
345+ batch-> seq_id = (llama_seq_id **) malloc (sizeof (llama_seq_id *) * (n_tokens_alloc + 1 ));
346346 for (int i = 0 ; i < n_tokens_alloc; ++i) {
347- batch. seq_id [i] = (llama_seq_id *) malloc (sizeof (llama_seq_id) * n_seq_max);
347+ batch-> seq_id [i] = (llama_seq_id *) malloc (sizeof (llama_seq_id) * n_seq_max);
348348 }
349- batch. seq_id [n_tokens_alloc] = nullptr ;
349+ batch-> seq_id [n_tokens_alloc] = nullptr ;
350350
351- batch. logits = (int8_t *) malloc (sizeof (int8_t ) * n_tokens_alloc);
351+ batch-> logits = (int8_t *) malloc (sizeof (int8_t ) * n_tokens_alloc);
352352
353353 return batch;
354354}
355355
356- void llama_batch_free (struct llama_batch batch) {
357- if (batch.token ) free (batch.token );
358- if (batch.embd ) free (batch.embd );
359- if (batch.pos ) free (batch.pos );
360- if (batch.n_seq_id ) free (batch.n_seq_id );
361- if (batch.seq_id ) {
362- for (int i = 0 ; batch.seq_id [i] != nullptr ; ++i) {
363- free (batch.seq_id [i]);
356+ struct llama_batch * llama_batch_init (int32_t n_tokens_alloc, int32_t n_seq_max) {
357+ return llama_batch_init_impl (n_tokens_alloc, 0 , n_seq_max);
358+ }
359+
360+ struct llama_batch * llama_batch_init_from_embd (
361+ float * embd,
362+ size_t n_embd,
363+ int32_t pos0,
364+ int32_t seq_id) {
365+ struct llama_batch * batch = llama_batch_init_impl (0 , n_embd, 1 );
366+ memcpy (batch->embd , embd, n_embd * sizeof (float ));
367+ for (int32_t i = 0 ; i < n_embd; i++) {
368+ batch->pos [i] = pos0 + i;
369+ batch->n_seq_id [i] = 1 ;
370+ batch->seq_id [i][0 ] = seq_id;
371+ }
372+ }
373+
374+ int32_t llama_batch_add_text (
375+ struct llama_batch * batch,
376+ llama_token * tokens,
377+ size_t n_tokens,
378+ int32_t pos0,
379+ int32_t * seq_ids,
380+ size_t n_seq_ids) {
381+ if (batch->n_tokens + n_tokens > batch->n_tokens ) {
382+ return -1 ;
383+ }
384+ if (batch->embd ) {
385+ return -2 ;
386+ }
387+ for (int32_t i = 0 ; i < n_tokens; i++) {
388+ batch->token [batch->n_tokens + i] = tokens[i];
389+ batch->pos [batch->n_tokens + i] = pos0 + i;
390+ batch->n_seq_id [batch->n_tokens + i] = n_seq_ids;
391+ for (int32_t j = 0 ; j < n_seq_ids; j++) {
392+ batch->seq_id [batch->n_tokens + i][j] = seq_ids[j];
393+ }
394+ }
395+ }
396+
397+ int32_t llama_batch_add_text (
398+ struct llama_batch * batch,
399+ llama_token * tokens,
400+ size_t n_tokens,
401+ int32_t pos0,
402+ int32_t seq_id) {
403+ std::array<int32_t , 1 > seq_ids = { seq_id };
404+ return llama_batch_add_text (batch, tokens, n_tokens, pos0, seq_ids.data (), seq_ids.size ());
405+ }
406+
407+ int32_t llama_batch_set_logits (
408+ struct llama_batch * batch,
409+ int32_t pos,
410+ int32_t seq_id) {
411+ for (int32_t i = 0 ; i < batch->n_tokens ; i++) {
412+ // find the token having seq_id
413+ for (int32_t j = 0 ; j < batch->n_seq_id [i]; j++) {
414+ if (batch->seq_id [i][j] == seq_id) {
415+ // found the sequence
416+ if (pos == -1 || pos == batch->pos [i]) {
417+ batch->logits [i] = true ;
418+ break ;
419+ }
420+ }
421+ }
422+ }
423+ }
424+
425+ void llama_batch_clear (struct llama_batch * batch) {
426+ batch->n_tokens = 0 ;
427+ }
428+
429+ void llama_batch_free (struct llama_batch * batch) {
430+ if (batch->token ) free (batch->token );
431+ if (batch->embd ) free (batch->embd );
432+ if (batch->pos ) free (batch->pos );
433+ if (batch->n_seq_id ) free (batch->n_seq_id );
434+ if (batch->seq_id ) {
435+ for (int i = 0 ; batch->seq_id [i] != nullptr ; ++i) {
436+ free (batch->seq_id [i]);
364437 }
365- free (batch. seq_id );
438+ free (batch-> seq_id );
366439 }
367- if (batch.logits ) free (batch.logits );
440+ if (batch->logits ) free (batch->logits );
441+ delete batch;
368442}
0 commit comments