@@ -210,7 +210,7 @@ bool llama_batch_allocr::init(
210210        LLAMA_LOG_DEBUG (" %s: input batch info:\n " 
211211
212212        llama_ubatch ubatch {
213-             /* .equal_seqs    =*/ false ,
213+             /* .b_equal_seqs  =*/ false ,
214214            /* .n_tokens     =*/ uint32_t ) batch.n_tokens ,
215215            /* .n_seq_tokens =*/ uint32_t ) 1 ,
216216            /* .n_seqs       =*/ uint32_t ) batch.n_tokens ,
@@ -223,6 +223,7 @@ bool llama_batch_allocr::init(
223223            /* .seq_id_unq   =*/ this ->seq_id_unq .data (),
224224            /* .seq_idx      =*/ this ->seq_idx .data (),
225225            /* .output       =*/ logits ,
226+             /* .data         =*/ 
226227        };
227228
228229        ubatch_print (ubatch, debug);
@@ -366,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
366367    clear ();
367368    split_reset ();
368369
369-     ubatches. emplace_back ();
370+     auto  udata = std::make_shared<llama_ubatch:: data_t > ();
370371
371-     auto  & ubatch = ubatches.back ();
372- 
373-     ubatch.token      .resize (n_tokens);
374-     ubatch.embd       .clear ();
375-     ubatch.pos        .resize (n_tokens);
376-     ubatch.n_seq_id   .resize (n_tokens);
377-     ubatch.seq_id     .resize (n_tokens);
378-     ubatch.seq_id_unq .resize (0 );
379-     ubatch.seq_idx    .resize (LLAMA_MAX_SEQ, -1 );
380-     ubatch.output     .resize (n_tokens);
372+     udata->token      .resize (n_tokens);
373+     udata->embd       .clear ();
374+     udata->pos        .resize (n_tokens);
375+     udata->n_seq_id   .resize (n_tokens);
376+     udata->seq_id     .resize (n_tokens);
377+     udata->seq_id_unq .resize (0 );
378+     udata->seq_idx    .resize (LLAMA_MAX_SEQ, -1 );
379+     udata->output     .resize (n_tokens);
381380
382381    for  (uint32_t  s = 0 ; s < n_seqs; ++s) {
383-         ubatch. seq_idx [s] = s;
384-         ubatch. seq_id_unq .push_back (s);
382+         udata-> seq_idx [s] = s;
383+         udata-> seq_id_unq .push_back (s);
385384    }
386385
387386    llama_ubatch res {
388-         /* .equal_seqs    =*/ true ,
387+         /* .b_equal_seqs  =*/ true ,
389388        /* .n_tokens     =*/ 
390389        /* .n_seq_tokens =*/ 
391390        /* .n_seqs       =*/ 
392391        /* .n_seqs_unq   =*/ 
393392
394-         /* .token        =*/ ubatch. token .data (),
393+         /* .token        =*/ udata-> token .data (),
395394        /* .embd         =*/ nullptr ,
396-         /* .pos          =*/ pos .data (),
397-         /* .n_seq_id     =*/ n_seq_id .data (),
398-         /* .seq_id       =*/ seq_id .data (),
399-         /* .seq_id_unq   =*/ seq_id_unq .data (),
400-         /* .seq_idx      =*/ seq_idx .data (),
401-         /* .output       =*/ output .data (),
395+         /* .pos          =*/ pos .data (),
396+         /* .n_seq_id     =*/ n_seq_id .data (),
397+         /* .seq_id       =*/ seq_id .data (),
398+         /* .seq_id_unq   =*/ seq_id_unq .data (),
399+         /* .seq_idx      =*/ seq_idx .data (),
400+         /* .output       =*/ output .data (),
401+         /* .data         =*/ std::move (udata),
402402    };
403403
404404    return  res;
@@ -439,8 +439,6 @@ void llama_batch_allocr::split_reset() {
439439
440440    used.clear ();
441441    used.resize (get_n_tokens (), false );
442- 
443-     ubatches.clear ();
444442}
445443
446444llama_ubatch llama_batch_allocr::split_simple (uint32_t  n_ubatch) {
@@ -655,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
655653
656654    assert (n_tokens%n_seqs == 0 );
657655
658-     ubatches.emplace_back ();
659- 
660-     auto  & ubatch = ubatches.back ();
656+     auto  udata = std::make_shared<llama_ubatch::data_t >();
661657
662658    const  int32_t  n_pos_cur = batch.embd  ? n_pos_per_embd : 1 ;
663659
664660    const  int64_t  n_embd_all = batch.embd  ? (int64_t ) n_tokens*n_embd : 0 ;
665661    const  int64_t  n_pos_all  =              (int64_t ) n_tokens*n_pos_cur;
666662
667-     ubatch. token      .resize (n_tokens);
668-     ubatch. embd       .resize (n_embd_all);
669-     ubatch. pos        .resize (n_pos_all);
670-     ubatch. n_seq_id   .resize (n_tokens);
671-     ubatch. seq_id     .resize (n_tokens);
672-     ubatch. seq_id_unq .resize (0 );
673-     ubatch. seq_idx    .resize (LLAMA_MAX_SEQ, -1 );
674-     ubatch. output     .resize (n_tokens);
663+     udata-> token      .resize (n_tokens);
664+     udata-> embd       .resize (n_embd_all);
665+     udata-> pos        .resize (n_pos_all);
666+     udata-> n_seq_id   .resize (n_tokens);
667+     udata-> seq_id     .resize (n_tokens);
668+     udata-> seq_id_unq .resize (0 );
669+     udata-> seq_idx    .resize (LLAMA_MAX_SEQ, -1 );
670+     udata-> output     .resize (n_tokens);
675671
676672    seq_set_t  seq_set_unq;
677673
678674    for  (size_t  i = 0 ; i < idxs.size (); ++i) {
679675        if  (batch.token ) {
680-             ubatch. token [i] = batch.token [idxs[i]];
676+             udata-> token [i] = batch.token [idxs[i]];
681677        }
682678
683679        if  (batch.embd ) {
684-             memcpy (ubatch. embd .data () + i*n_embd, batch.embd  + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
680+             memcpy (udata-> embd .data () + i*n_embd, batch.embd  + (int64_t ) idxs[i]*n_embd, n_embd*sizeof (float ));
685681        }
686682
687683        for  (int  j = 0 ; j < n_pos_cur; ++j) {
688-             ubatch. pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens  + idxs[i]];
684+             udata-> pos [j*n_tokens + i] = batch.pos [j*batch.n_tokens  + idxs[i]];
689685        }
690686
691-         ubatch. n_seq_id [i] = batch.n_seq_id [idxs[i]];
692-         ubatch. seq_id [i]   = batch.seq_id [idxs[i]];
693-         ubatch. output [i]   = batch.logits [idxs[i]];
687+         udata-> n_seq_id [i] = batch.n_seq_id [idxs[i]];
688+         udata-> seq_id [i]   = batch.seq_id [idxs[i]];
689+         udata-> output [i]   = batch.logits [idxs[i]];
694690
695-         for  (int  s = 0 ; s < ubatch. n_seq_id [i]; ++s) {
696-             seq_set_unq.set (ubatch. seq_id [i][s]);
691+         for  (int  s = 0 ; s < udata-> n_seq_id [i]; ++s) {
692+             seq_set_unq.set (udata-> seq_id [i][s]);
697693        }
698694
699-         if  (ubatch. output [i]) {
695+         if  (udata-> output [i]) {
700696            out_ids.push_back (idxs[i]);
701697        }
702698    }
703699
704700    for  (uint32_t  s = 0 ; s < n_seq_max; ++s) {
705701        if  (seq_set_unq.test (s)) {
706-             ubatch. seq_idx [s] = ubatch. seq_id_unq .size ();
707-             ubatch. seq_id_unq .push_back (s);
702+             udata-> seq_idx [s] = udata-> seq_id_unq .size ();
703+             udata-> seq_id_unq .push_back (s);
708704        }
709705    }
710706
711707    llama_ubatch res {
712-         /* .equal_seqs    =*/ 
708+         /* .b_equal_seqs  =*/ 
713709        /* .n_tokens     =*/ 
714710        /* .n_seq_tokens =*/ 
715711        /* .n_seqs       =*/ 
716-         /* .n_seqs_unq   =*/ uint32_t ) ubatch.seq_id_unq .size (),
717- 
718-         /* .token        =*/ token  ? ubatch.token .data () : nullptr ,
719-         /* .embd         =*/ embd  ? ubatch.embd .data () : nullptr ,
720-         /* .pos          =*/ pos .data (),
721-         /* .n_seq_id     =*/ n_seq_id .data (),
722-         /* .seq_id       =*/ seq_id .data (),
723-         /* .seq_id_unq   =*/ seq_id_unq .data (),
724-         /* .seq_idx      =*/ seq_idx .data (),
725-         /* .output       =*/ output .data (),
712+         /* .n_seqs_unq   =*/ uint32_t ) udata->seq_id_unq .size (),
713+ 
714+         /* .token        =*/ token  ? udata->token .data () : nullptr ,
715+         /* .embd         =*/ embd  ? udata->embd .data () : nullptr ,
716+         /* .pos          =*/ pos .data (),
717+         /* .n_seq_id     =*/ n_seq_id .data (),
718+         /* .seq_id       =*/ seq_id .data (),
719+         /* .seq_id_unq   =*/ seq_id_unq .data (),
720+         /* .seq_idx      =*/ seq_idx .data (),
721+         /* .output       =*/ output .data (),
722+         /* .data         =*/ std::move (udata),
726723    };
727724
728725    if  (debug > 0 ) {
729-         LLAMA_LOG_DEBUG (" %s: added ubatch %d  to split:\n " , ( int ) ubatches. size () -  1 );
726+         LLAMA_LOG_DEBUG (" %s: added ubatch to split:\n " 
730727
731728        ubatch_print (res, debug);
732729    }
@@ -736,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
736733
737734void  llama_batch_allocr::ubatch_print (const  llama_ubatch & ubatch, int  debug) {
738735    if  (debug > 0 ) {
739-         LLAMA_LOG_DEBUG (" %s:   equal_seqs   = %d\n " equal_seqs );
736+         LLAMA_LOG_DEBUG (" %s:   equal_seqs   = %d\n " equal_seqs () );
740737        LLAMA_LOG_DEBUG (" %s:   n_tokens     = %d\n " n_tokens );
741738        LLAMA_LOG_DEBUG (" %s:   n_seq_tokens = %d\n " n_seq_tokens );
742739        LLAMA_LOG_DEBUG (" %s:   n_seqs       = %d\n " n_seqs );
0 commit comments