@@ -251,7 +251,7 @@ llama_context::llama_context(
     }

     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

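The added `memory` check means the worst-case graph is reserved only for contexts that actually allocate memory (a KV cache); for models that create none (e.g. encoder-only embedding models) the reservation is skipped. A minimal caller-side sketch, assuming the public llama.h API; the model path is a placeholder:

    // hypothetical model path; an encoder-only embedding model creates no
    // KV cache, so `memory` stays null inside the resulting context
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("embedding-model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // encode-only usage
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // pooled sequence embeddings

    llama_context * ctx = llama_init_from_model(model, cparams);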
@@ -700,6 +700,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }

+    embd_seq.clear();
+
     n_queued_tokens += n_tokens;

     const int64_t n_embd = hparams.n_embd;
@@ -761,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);

-        GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    GGML_ASSERT(embd != nullptr);
+
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
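Moving `GGML_ASSERT(embd != nullptr)` into the `LLAMA_POOLING_TYPE_NONE` branch reflects that pooled outputs never touch the token-level `embd` buffer; they are written to the per-sequence map `embd_seq` instead. Roughly, the caller-side split (a sketch, not part of this patch):

    // token-level output, only populated with LLAMA_POOLING_TYPE_NONE:
    const float * tok_embd = llama_get_embeddings_ith(ctx, i);

    // pooled output, read back per sequence from embd_seq:
    const float * seq_embd = llama_get_embeddings_seq(ctx, seq_id);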
@@ -791,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                    //       wait for an encoder model that requires this pooling type in order to test it
-                    //       https://github.com/ggerganov/llama.cpp/pull/9510
-                    GGML_ABORT("RANK pooling not implemented yet");
-                }
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
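With the RANK branch implemented, a reranker score can be read back as a single float per sequence after encoding. A hedged usage sketch (batch construction is illustrative; a real reranker prompt concatenates query and document tokens into one sequence):

    std::vector<llama_token> tokens; // fill via llama_tokenize() with query + document
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_encode(ctx, batch) == 0) {
        // one float per sequence - the score extracted by the RANK case above
        const float * score = llama_get_embeddings_seq(ctx, /*seq_id=*/0);
        if (score) {
            printf("rerank score: %f\n", score[0]);
        }
    }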
@@ -833,6 +842,11 @@ int llama_context::encode(llama_batch & inp_batch) {
 }

 int llama_context::decode(llama_batch & inp_batch) {
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        return encode(inp_batch);
+    }
+
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
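The early-out makes `llama_decode()` tolerant of cache-less contexts: instead of failing, it logs a warning and forwards the batch to `encode()`. Caller code can therefore stay uniform (sketch; `ctx` and `batch` as in the examples above):

    // on a context without memory this logs a warning and then
    // behaves exactly like llama_encode():
    const int ret = llama_decode(ctx, batch);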