@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }

     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
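
The new fourth argument requests that every token in the batch be marked as an output; encode() always passes true, consistent with the note above that the encoder processes the full sequence. The allocator internals are not visible in this hunk, so the sketch below is only a plausible reduction of what the call sites imply (the output_all name and the body are assumptions, not the real llama_batch_allocr):

    // Hypothetical sketch only -- output_all and this body are inferred from
    // the call sites in this diff, not taken from the real llama_batch_allocr.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct batch_allocr_sketch {
        uint32_t n_outputs = 0;

        bool init(const std::vector<int32_t> & tokens,
                  const int8_t * output_flags, // per-token flags, may be nullptr
                  bool output_all) {
            if (output_all) {
                n_outputs = (uint32_t) tokens.size(); // every token is an output
            } else {
                n_outputs = 0;
                for (size_t i = 0; i < tokens.size(); ++i) {
                    if (output_flags && output_flags[i]) {
                        n_outputs++;
                    }
                }
            }
            return true;
        }
    };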
@@ -899,7 +899,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }

-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
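
The predicate itself is the heart of the change: previously all-token output was forced only when pooling was enabled, now any embeddings request forces it. Reduced to two standalone functions (a sketch; cparams_view and the enum stand in for the relevant parts of llama_cparams):

    // Sketch: cparams_view stands in for the relevant fields of llama_cparams.
    enum pooling_type_sketch { POOLING_NONE, POOLING_MEAN, POOLING_CLS, POOLING_LAST };

    struct cparams_view {
        bool                embeddings;
        pooling_type_sketch pooling_type;
    };

    // before: only pooled embeddings required every token to be output
    static bool embd_pooled_old(const cparams_view & cp) {
        return cp.embeddings && cp.pooling_type != POOLING_NONE;
    }

    // after: any embeddings request does, pooled or per-token
    static bool embd_all_new(const cparams_view & cp) {
        return cp.embeddings;
    }

In particular, embeddings with LLAMA_POOLING_TYPE_NONE (per-token embeddings) now also take the all-output path.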
@@ -916,12 +919,9 @@ int llama_context::decode(const llama_batch & batch_inp) {

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();

-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", __func__, n_outputs_all, n_tokens_all);
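
The check guarded by the renamed flag is unchanged: with embd_all set, the number of requested outputs must equal the batch size. A minimal standalone restatement of that guard (sketch, not library code):

    #include <cstdint>
    #include <cstdio>

    // Sketch of the guard above: returns false exactly when an all-output
    // batch does not mark every token as an output.
    static bool check_all_outputs(uint32_t n_outputs_all, uint32_t n_tokens_all) {
        if (n_outputs_all != n_tokens_all) {
            fprintf(stderr, "embedding requires that all tokens are output (n_outputs_all = %u, n_tokens_all = %u)\n",
                    (unsigned) n_outputs_all, (unsigned) n_tokens_all);
            return false;
        }
        return true;
    }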
@@ -950,7 +950,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;

     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
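
The same flag is now forwarded to memory->init_batch(), which takes it into account when slicing the batch into ubatches. How exactly it is consumed is not visible in this diff; one assumed possibility, based on the general idea that all-output (embedding) batches must keep whole sequences together, is sketched below:

    // Assumption, not the real cache logic: an all-output batch is split so
    // that each ubatch carries complete sequences (needed e.g. for pooling),
    // while a normal decode batch can use the plain contiguous split.
    enum split_mode_sketch { SPLIT_SIMPLE, SPLIT_EQUAL };

    static split_mode_sketch choose_split(bool embd_all) {
        return embd_all ? SPLIT_EQUAL : SPLIT_SIMPLE;
    }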
@@ -2052,14 +2052,11 @@ void llama_context::opt_epoch_iter(

         n_queued_tokens += n_tokens_all;

-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();

         uint32_t n_outputs_all = n_tokens_all;

-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
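
In the training path the flag is simply hard-coded to true: opt_epoch_iter() pins n_outputs_all to n_tokens_all a few lines above, so every token is an output and there is nothing to derive from the pooling type. A small restatement of that invariant (sketch):

    #include <cstdint>

    // mirrors opt_epoch_iter(): with n_outputs_all pinned to n_tokens_all,
    // the all-output flag passed to init_batch is unconditionally true
    static bool opt_outputs_all(uint32_t n_tokens_all) {
        const uint32_t n_outputs_all = n_tokens_all;
        return n_outputs_all == n_tokens_all; // always true
    }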