@@ -81,6 +81,12 @@ llama_context::llama_context(
         }
     }
 
+    if (!cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        LLAMA_LOG_WARN("%s: pooling_type is set to %d but embeddings is set to false - disabling pooling\n", __func__, cparams.pooling_type);
+
+        cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+    }
+
     if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
         cparams.causal_attn = hparams.causal_attn;
     } else {
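
For illustration, a minimal sketch of the mismatched configuration this new guard targets: a pooling type requested on a context that has embeddings disabled. The calls below (`llama_model_load_from_file`, `llama_context_default_params`, `llama_init_from_model`) are from the public `llama.h` as I understand it, and the model path is a placeholder; this is a demonstration, not the canonical usage.

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    // placeholder model path - any GGUF model will do for this demonstration
    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = false;                    // logits-only context ...
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // ... yet a pooling type is requested

    // with the check added above, context creation logs the
    // "pooling_type is set to ... but embeddings is set to false" warning
    // and resets cparams.pooling_type to LLAMA_POOLING_TYPE_NONE
    llama_context * ctx = llama_init_from_model(model, cparams);

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```
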
@@ -728,7 +734,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +900,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +920,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +951,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1064,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //     ggml_graph_dump_dot(gf, NULL, "llama.dot");
         // }
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1228,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2049,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
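
To show the effect of the `t_logits` and `output_reserve` changes, here is a hedged sketch of the usage they enable: with embeddings turned on, logits are no longer discarded, so a single `llama_decode` can serve both a pooled embedding and next-token logits. The calls (`llama_init_from_model`, `llama_batch_get_one`, `llama_get_embeddings_seq`, `llama_get_logits_ith`) are from the public `llama.h` as I understand it; the model path and token ids are placeholders, not real inputs.

```cpp
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // reserve the embeddings output buffer
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // pool per sequence

    llama_context * ctx = llama_init_from_model(model, cparams);

    std::vector<llama_token> tokens = {1, 2, 3};     // placeholder token ids
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) == 0) {
        // pooled embedding for sequence 0 (n_embd floats)
        const float * embd = llama_get_embeddings_seq(ctx, 0);

        // logits are kept as well now - before this change t_logits was nullptr
        // whenever cparams.embeddings was enabled
        const float * logits = llama_get_logits_ith(ctx, (int32_t) tokens.size() - 1);

        printf("embd[0] = %f, logits[0] = %f\n",
               embd   ? embd[0]   : 0.0f,
               logits ? logits[0] : 0.0f);
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```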