Skip to content

Commit d1d89e0

Browse files
committed
llama-context: add ability to get logits
1 parent: 0d92267 · commit: d1d89e0

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/llama-context.cpp

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
731731

732732
const auto & hparams = model.hparams;
733733

734-
const int64_t n_embd = hparams.n_embd;
734+
const int64_t n_embd = hparams.n_embd;
735+
const int32_t n_vocab = model.vocab.n_tokens();
735736

736737
// note: during encode, we always pass the full sequence starting from pos = 0
737738
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
@@ -791,10 +792,22 @@ int llama_context::encode(const llama_batch & batch_inp) {
791792
}
792793
}
793794

795+
auto * t_logits = res->get_logits();
794796
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
795797

798+
// extract logits
799+
if (t_logits && n_outputs > 0) {
800+
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
801+
GGML_ASSERT(backend_res != nullptr);
802+
GGML_ASSERT(logits != nullptr);
803+
804+
if (n_outputs) {
805+
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_outputs*n_vocab*sizeof(float));
806+
}
807+
}
808+
796809
// extract embeddings
797-
if (t_embd) {
810+
if (cparams.embeddings && t_embd) {
798811
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
799812
GGML_ASSERT(backend_embd != nullptr);
800813

0 commit comments

Comments
 (0)