@@ -902,23 +902,41 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd  = hparams.n_embd;
 
-    // when computing embeddings, all tokens are output
-    const bool output_all = cparams.embeddings;
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd       = hparams.n_embd;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    // TODO: move the validation to the llama_batch_allocr
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
+            }
+        }
     }
 
-    const uint32_t n_tokens_all  = balloc->get_n_tokens();
-    const uint32_t n_outputs_all = balloc->get_n_outputs();
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
 
-    if (output_all) {
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
-            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
                 __func__, n_outputs_all, n_tokens_all);
             return -1;
         }
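The counting loop above derives n_outputs_all entirely from the caller-supplied per-token logits flags, while the new validation rejects out-of-range token ids and sequence ids up front. For reference, here is a minimal sketch of a well-formed batch from the caller's side, using the public llama_batch API from llama.h (tokens and n are assumed inputs; model and context setup are omitted):

    // one sequence, only the last token requests an output,
    // so the counting loop above yields n_outputs_all == 1
    llama_batch batch = llama_batch_init(/*n_tokens*/ n, /*embd*/ 0, /*n_seq_max*/ 1);

    for (int32_t i = 0; i < n; ++i) {
        batch.token[i]     = tokens[i];    // must lie in [0, n_vocab), else decode() returns -1
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;            // must lie in [0, LLAMA_MAX_PARALLEL_SEQUENCES)
        batch.logits[i]    = i == n - 1;   // per-token output flag counted into n_outputs_all
    }
    batch.n_tokens = n;

    // ... llama_decode(ctx, batch); ...
    llama_batch_free(batch);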
@@ -2045,6 +2063,9 @@ void llama_context::opt_epoch_iter(
 
     n_queued_tokens += n_tokens_all;
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     embd_seq.clear();
 
     uint32_t n_outputs_all = n_tokens_all;
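The same embd_pooled flag is now also computed in llama_context::opt_epoch_iter. For reference, it becomes true under a configuration like the following sketch (mean pooling is an arbitrary choice here; any pooling type other than LLAMA_POOLING_TYPE_NONE behaves the same for this check):

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // request embeddings from decode
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // != LLAMA_POOLING_TYPE_NONE

    // with these params, decode() requires an output flag on every token
    // (n_outputs_all == n_tokens_all) and fails with -1 otherwise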