@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }

+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

     n_queued_tokens += n_tokens;
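
The TODO added here (and mirrored in decode() below) notes that clearing embd_seq by hand at every entry point is easy to forget. One possible direction, not part of this change, is a small guard that clears the buffer on scope entry; a minimal sketch, assuming embd_seq stays a plain std::map keyed by sequence id:

    // hypothetical helper, not in the patch: clears a container as soon as the guard is constructed,
    // so every entry point that reuses embd_seq starts from a known-empty state
    template <typename T>
    struct clear_on_entry {
        explicit clear_on_entry(T & c) { c.clear(); }
    };

    // usage at the top of encode()/decode():
    //     clear_on_entry guard(embd_seq);
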
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }

+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
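
With this change, decode() derives the output count solely from batch.logits and rejects a pooled-embedding batch unless every token is marked for output. On the caller side that means explicitly enabling output for every token, roughly as in the sketch below (illustrative only; ctx, tokens and n_tokens are assumed to exist, and the context must have been created with embeddings and a pooling type enabled):

    // build a single-sequence batch where every token is marked for output,
    // which the new check in llama_context::decode() requires for pooled embeddings
    llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = 1; // leaving any of these at 0 now makes the pooled decode fail with -1
    }
    batch.n_tokens = n_tokens;

    if (llama_decode(ctx, batch) != 0) {
        // previously a pooled batch ignored batch.logits; now a partial output mask is an error
    }

    const float * embd = llama_get_embeddings_seq(ctx, 0); // pooled embedding for sequence 0

    llama_batch_free(batch);
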
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;

-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;

     // handle any pending defrags/shifts
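
For reference, the removed branching can be read as a small standalone helper (a paraphrase of the old logic, not code from the patch). The new version keeps only the unconditional count over batch.logits and turns the pooled case into the hard error added above, which assumes batch.logits is always populated by the time decode() runs:

    // old output-counting behavior of llama_context::decode(), paraphrased:
    static int64_t count_outputs_old(const int8_t * logits, int64_t n_tokens, bool embd_pooled) {
        if (logits && !embd_pooled) {
            int64_t n = 0;
            for (int64_t i = 0; i < n_tokens; ++i) {
                n += logits[i] != 0; // count tokens explicitly marked for output
            }
            return n;
        }
        if (embd_pooled) {
            return n_tokens; // pooled embeddings: output every token, ignoring logits
        }
        return 1; // default: keep the last output only
    }
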
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();

-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;

@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(

     n_queued_tokens += n_tokens_all;

-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

     embd_seq.clear();