@@ -88,6 +88,26 @@ enum error_type {
8888 ERROR_TYPE_NOT_SUPPORTED, // custom error
8989};
9090
91+ static bool server_task_type_need_embd (server_task_type task_type) {
92+ switch (task_type) {
93+ case SERVER_TASK_TYPE_EMBEDDING:
94+ case SERVER_TASK_TYPE_RERANK:
95+ return true ;
96+ default :
97+ return false ;
98+ }
99+ }
100+
101+ static bool server_task_type_need_logits (server_task_type task_type) {
102+ switch (task_type) {
103+ case SERVER_TASK_TYPE_COMPLETION:
104+ case SERVER_TASK_TYPE_INFILL:
105+ return true ;
106+ default :
107+ return false ;
108+ }
109+ }
110+
91111struct slot_params {
92112 bool stream = true ;
93113 bool cache_prompt = true ; // remember the prompt to avoid reprocessing all prompt
@@ -1330,6 +1350,14 @@ struct server_slot {
13301350 n_draft_accepted = 0 ;
13311351 }
13321352
1353+ bool need_embd () const {
1354+ return server_task_type_need_embd (task_type);
1355+ }
1356+
1357+ bool need_logits () const {
1358+ return server_task_type_need_logits (task_type);
1359+ }
1360+
13331361 bool can_batch_with (server_slot & other_slot) const {
13341362 return task_type == other_slot.task_type && are_lora_equal (lora, other_slot.lora );
13351363 }
@@ -3095,6 +3123,13 @@ struct server_context {
30953123 continue ;
30963124 }
30973125
3126+ // TODO: support memory-less logits computation
3127+ if (slot.need_logits () && !llama_get_memory (ctx)) {
3128+ slot.release ();
3129+ send_error (slot, " the current context does not logits computation. skipping" , ERROR_TYPE_SERVER);
3130+ continue ;
3131+ }
3132+
30983133 if (!can_split ()) {
30993134 if (slot.n_prompt_tokens > n_ubatch) {
31003135 slot.release ();
@@ -3292,7 +3327,7 @@ struct server_context {
32923327 }
32933328
32943329 // embedding requires all tokens in the batch to be output
3295- const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING || slot. task_type == SERVER_TASK_TYPE_RERANK ;
3330+ const bool need_embd = server_task_type_need_embd ( slot.task_type ) ;
32963331
32973332 common_batch_add (batch, cur_tok, slot.n_past , { slot.id }, need_embd);
32983333 slot.cache_tokens .push_back (cur_tok);
@@ -3348,13 +3383,13 @@ struct server_context {
33483383 if (slot_batched) {
33493384 // apply lora, only need to do it once per batch
33503385 common_set_adapter_lora (ctx, slot_batched->lora );
3351- }
33523386
3353- const bool do_encode = params_base.embedding ;
3387+ llama_set_embeddings (ctx, slot_batched->need_embd ());
3388+ }
33543389
33553390 // pad the batch so that batch.n_tokens >= n_slots
33563391 // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
3357- if (do_encode ) {
3392+ if (llama_get_embeddings (ctx) ) {
33583393 const int n_slots = slots.size ();
33593394
33603395 if (batch.n_tokens < n_slots) {
@@ -4175,11 +4210,6 @@ int main(int argc, char ** argv) {
41754210 oaicompat_type oaicompat) -> void {
41764211 GGML_ASSERT (type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
41774212
4178- if (ctx_server.params_base .embedding ) {
4179- res_error (res, format_error_response (" This server does not support completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
4180- return ;
4181- }
4182-
41834213 auto completion_id = gen_chatcmplid ();
41844214 std::unordered_set<int > task_ids;
41854215 try {
@@ -4434,12 +4464,8 @@ int main(int argc, char ** argv) {
44344464 OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
44354465 };
44364466
4437- const auto handle_chat_completions = [&ctx_server, &res_error, & handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
4467+ const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
44384468 LOG_DBG (" request: %s\n " , req.body .c_str ());
4439- if (ctx_server.params_base .embedding ) {
4440- res_error (res, format_error_response (" This server does not support completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
4441- return ;
4442- }
44434469
44444470 auto body = json::parse (req.body );
44454471 std::vector<raw_buffer> files;
0 commit comments