@@ -54,7 +54,10 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_INFERENCE,
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_inf_type {
-    SERVER_TASK_INF_TYPE_COMPLETION,
-    SERVER_TASK_INF_TYPE_EMBEDDING,
-    SERVER_TASK_INF_TYPE_RERANK,
-    SERVER_TASK_INF_TYPE_INFILL,
-};
-
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -163,8 +159,7 @@ struct server_task {
     int id    = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
 
-    server_task_type type;
-    server_task_inf_type inf_type;
+    server_task_type type;
 
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
@@ -185,9 +180,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
-    server_task(
-            server_task_type type,
-            server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
+    server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
@@ -893,6 +886,9 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
     llama_batch batch_spec = {};
 
     llama_context * ctx = nullptr;
@@ -931,8 +927,6 @@ struct server_slot {
     llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
     bool has_next_token = true;
     bool has_new_line   = false;
     bool truncated      = false;
@@ -972,11 +966,15 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         n_sent_token_probs = 0;
-        inf_type           = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type          = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
 
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1088,6 +1086,7 @@ struct server_slot {
             {"n_ctx",         n_ctx},
             {"speculative",   can_speculate()},
             {"is_processing", is_processing()},
+            {"non_causal",    is_non_causal()},
             {"params",        params.to_json()},
             {"prompt",        common_detokenize(ctx, prompt_tokens)},
             {"next_token",
@@ -1653,8 +1652,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot.reset();
         slot.id_task = task.id;
-        slot.inf_type = task.inf_type;
         slot.index = task.index;
+        slot.task_type = task.type;
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
@@ -2120,7 +2119,10 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_INFERENCE:
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
                     const int id_slot = task.id_selected_slot;
 
@@ -2462,7 +2464,7 @@ struct server_context {
                 continue;
             }
 
-            if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+            if (slot.is_non_causal()) {
                 if (slot.n_prompt_tokens > n_ubatch) {
                     slot.release();
                     send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2577,18 +2579,15 @@ struct server_context {
                 }
 
                 // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                if (slot.is_non_causal()) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
 
                 // check that we are in the right batch_type, if not defer the slot
-                const bool slot_type =
-                    slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                    slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+                int slot_type = slot.is_non_causal();
                 if (batch_type == -1) {
                     batch_type = slot_type;
                 } else if (batch_type != slot_type) {
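
Embedding and rerank prompts are decoded non-causally, so they cannot share a llama_batch with causal completion/infill work. The hunk above therefore tags the in-flight batch with the type of the first slot added (batch_type) and defers any slot of the other type to a later iteration. A standalone C++ sketch of that grouping rule, with a hypothetical mini_slot standing in for the real server_slot:

    #include <cstdio>
    #include <vector>

    struct mini_slot {
        bool non_causal; // true for embedding/rerank, false for completion/infill
        int  n_tokens;
    };

    int main() {
        std::vector<mini_slot> slots = {{false, 8}, {true, 4}, {false, 2}};

        int batch_type = -1; // -1 = batch still empty, 0 = causal, 1 = non-causal
        for (const auto & slot : slots) {
            const int slot_type = slot.non_causal ? 1 : 0;
            if (batch_type == -1) {
                batch_type = slot_type;   // the first slot decides the batch type
            } else if (batch_type != slot_type) {
                continue;                 // defer mismatched slots to a later batch
            }
            std::printf("batched %d tokens (type %d)\n", slot.n_tokens, slot_type);
        }
        return 0;
    }

Here the middle (non-causal) slot is skipped: the batch was already marked causal, so the embedding-style work waits for a following batch, exactly like the deferred slots above.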
@@ -2705,15 +2704,15 @@ struct server_context {
             }
 
             if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+                if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                     // prompt evaluated for embedding
                     send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue; // continue loop of slots
                 }
 
-                if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                     send_rerank(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
@@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
-            server_task_inf_type inf_type,
+            server_task_type type,
             json & data,
             httplib::Response & res,
             bool oaicompat = false,
             bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
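
The new GGML_ASSERT makes the contract of the shared handler explicit: handle_completions_generic now serves only the completion-like task types (completion and infill, including the chat endpoint), while embedding and rerank build their tasks in their own handlers further down. A tiny self-contained illustration of the same guard, using a plain assert in place of the real GGML_ASSERT macro and a hypothetical handle_completion_like stand-in:

    #include <cassert>

    enum server_task_type {
        SERVER_TASK_TYPE_COMPLETION,
        SERVER_TASK_TYPE_EMBEDDING,
        SERVER_TASK_TYPE_RERANK,
        SERVER_TASK_TYPE_INFILL,
    };

    // stand-in for handle_completions_generic's entry check
    static void handle_completion_like(server_task_type type) {
        assert(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
        (void) type; // ... tokenize the prompt, enqueue tasks, stream results ...
    }

    int main() {
        handle_completion_like(SERVER_TASK_TYPE_COMPLETION); // ok
        handle_completion_like(SERVER_TASK_TYPE_INFILL);     // ok
        // handle_completion_like(SERVER_TASK_TYPE_EMBEDDING) would trip the assert
        return 0;
    }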
@@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
+            server_task task = server_task(type);
+
             task.id    = ctx_server.queue_tasks.get_new_id();
             task.index = i;
 
@@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ false,
@@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it doesn't exist
 
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ true,
@@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
+            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
             task.id    = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
@@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
         tasks.reserve(tokenized_docs.size());
         for (size_t i = 0; i < tokenized_docs.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
+            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
             task.id    = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
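
Taken together, the endpoint hunks show the new flow end to end: each handler builds its tasks with a concrete server_task_type, launch_slot_with_task copies task.type into slot.task_type, and is_non_causal() derives the embedding/rerank special-casing from that single field. A compilable sketch of the whole path, reproducing only the members this diff touches:

    enum server_task_type {
        SERVER_TASK_TYPE_COMPLETION,
        SERVER_TASK_TYPE_EMBEDDING,
        SERVER_TASK_TYPE_RERANK,
        SERVER_TASK_TYPE_INFILL,
    };

    struct server_task {
        server_task_type type;
        server_task(server_task_type type) : type(type) {}
    };

    struct server_slot {
        server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

        bool is_non_causal() const {
            return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
        }
    };

    int main() {
        // before this diff: server_task task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
        server_task task(SERVER_TASK_TYPE_RERANK); // what the rerank handler now constructs

        server_slot slot;
        slot.task_type = task.type;          // what launch_slot_with_task now copies
        return slot.is_non_causal() ? 0 : 1; // rerank slots take the non-causal path
    }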