@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //     ggml_graph_dump_dot(gf, NULL, "llama.dot");
         // }
 
-        auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
        embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
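
For reference, a minimal caller-side sketch of what the new embd_all behavior implies: with cparams.embeddings enabled, decode() now treats every token in the batch as an output, so per-token embeddings can be read back for each position. The fragment below uses the public llama.h API (llama_batch_get_one, llama_decode, llama_get_embeddings_ith) and is not part of this change; it assumes a context created with embeddings = true and pooling_type = LLAMA_POOLING_TYPE_NONE, an already-tokenized prompt, and signatures as of the time of writing.

// Sketch only: `ctx` was created with cparams.embeddings = true and
// cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; `tokens` holds the tokenized prompt.
#include "llama.h"

#include <cstdio>
#include <vector>

static void print_token_embeddings(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "decode failed\n");
        return;
    }

    for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
        // with embeddings enabled, every position in the batch has an output row
        const float * embd = llama_get_embeddings_ith(ctx, i);
        if (embd) {
            printf("token %d: embd[0] = %f\n", i, embd[0]);
        }
    }
}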