@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", __func__, n_outputs_all, n_tokens_all);
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
            return -2;
        }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //     ggml_graph_dump_dot(gf, NULL, "llama.dot");
         // }
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
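
In short, the "output everything" decision no longer depends on the pooling type. A minimal sketch of the old vs. new derivation, not part of the diff above, using the public llama_context_params struct from llama.h for illustration (the patched code reads the equivalent fields from its internal cparams):

#include "llama.h"

// old derivation (embd_pooled): full output was forced only for pooled embeddings
static bool output_all_before(const llama_context_params & cparams) {
    return cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
}

// new derivation (embd_all): any embedding request outputs every token,
// and output_reserve() now always reserves the logits buffer as well
static bool output_all_after(const llama_context_params & cparams) {
    return cparams.embeddings;
}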