@@ -67,6 +67,13 @@ enum server_task_type {
6767 SERVER_TASK_TYPE_SET_LORA,
6868};
6969
// Which OpenAI-compatible response format a task should produce.
// Selects the JSON shape emitted by the to_json() methods of the result structs
// and gates OAI-only behavior (e.g. the "data: [DONE]" SSE terminator).
enum oaicompat_type {
    OAICOMPAT_TYPE_NONE,       // server-native (non-OAI) JSON response
    OAICOMPAT_TYPE_CHAT,       // /v1/chat/completions ("chat.completion" objects)
    OAICOMPAT_TYPE_COMPLETION, // /v1/completions ("text_completion" objects)
    OAICOMPAT_TYPE_EMBEDDING,  // /v1/embeddings (OAI embeddings response)
};
76+
7077// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
7178enum error_type {
7279 ERROR_TYPE_INVALID_REQUEST,
@@ -101,11 +108,10 @@ struct slot_params {
101108 struct common_params_speculative speculative;
102109
103110 // OAI-compat fields
104- bool verbose = false ;
105- bool oaicompat = false ;
106- bool oaicompat_chat = true ;
107- std::string oaicompat_model;
108- std::string oaicompat_cmpl_id;
111+ bool verbose = false ;
112+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
113+ std::string oaicompat_model;
114+ std::string oaicompat_cmpl_id;
109115
110116 json to_json () const {
111117 std::vector<std::string> samplers;
@@ -529,11 +535,10 @@ struct server_task_result_cmpl_final : server_task_result {
529535 slot_params generation_params;
530536
531537 // OAI-compat fields
532- bool verbose = false ;
533- bool oaicompat = false ;
534- bool oaicompat_chat = true ; // TODO: support oaicompat for non-chat
535- std::string oaicompat_model;
536- std::string oaicompat_cmpl_id;
538+ bool verbose = false ;
539+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
540+ std::string oaicompat_model;
541+ std::string oaicompat_cmpl_id;
537542
538543 virtual int get_index () override {
539544 return index;
@@ -544,9 +549,16 @@ struct server_task_result_cmpl_final : server_task_result {
544549 }
545550
546551 virtual json to_json () override {
547- return oaicompat
548- ? (stream ? to_json_oaicompat_chat_stream () : to_json_oaicompat_chat ())
549- : to_json_non_oaicompat ();
552+ switch (oaicompat) {
553+ case OAICOMPAT_TYPE_NONE:
554+ return to_json_non_oaicompat ();
555+ case OAICOMPAT_TYPE_COMPLETION:
556+ return to_json_oaicompat ();
557+ case OAICOMPAT_TYPE_CHAT:
558+ return stream ? to_json_oaicompat_chat_stream () : to_json_oaicompat_chat ();
559+ default :
560+ GGML_ASSERT (false && " Invalid oaicompat_type" );
561+ }
550562 }
551563
552564 json to_json_non_oaicompat () {
@@ -574,6 +586,50 @@ struct server_task_result_cmpl_final : server_task_result {
574586 return response_fields.empty () ? res : json_get_nested_values (response_fields, res);
575587 }
576588
589+ json to_json_oaicompat () {
590+ std::time_t t = std::time (0 );
591+ json logprobs = json (nullptr ); // OAI default to null
592+ if (!stream && probs_output.size () > 0 ) {
593+ logprobs = json{
594+ {" content" , completion_token_output::probs_vector_to_json (probs_output, post_sampling_probs)},
595+ };
596+ }
597+ json finish_reason = " length" ;
598+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
599+ finish_reason = " stop" ;
600+ }
601+ json res = json {
602+ {" choices" , json::array ({
603+ json{
604+ {" text" , stream ? " " : content}, // in stream mode, content is already in last partial chunk
605+ {" index" , index},
606+ {" logprobs" , logprobs},
607+ {" finish_reason" , finish_reason},
608+ }
609+ })},
610+ {" created" , t},
611+ {" model" , oaicompat_model},
612+ {" system_fingerprint" , build_info},
613+ {" object" , " text_completion" },
614+ {" usage" , json {
615+ {" completion_tokens" , n_decoded},
616+ {" prompt_tokens" , n_prompt_tokens},
617+ {" total_tokens" , n_decoded + n_prompt_tokens}
618+ }},
619+ {" id" , oaicompat_cmpl_id}
620+ };
621+
622+ // extra fields for debugging purposes
623+ if (verbose) {
624+ res[" __verbose" ] = to_json_non_oaicompat ();
625+ }
626+ if (timings.prompt_n >= 0 ) {
627+ res.push_back ({" timings" , timings.to_json ()});
628+ }
629+
630+ return res;
631+ }
632+
577633 json to_json_oaicompat_chat () {
578634 std::string finish_reason = " length" ;
579635 if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
@@ -671,11 +727,10 @@ struct server_task_result_cmpl_partial : server_task_result {
671727 result_timings timings;
672728
673729 // OAI-compat fields
674- bool verbose = false ;
675- bool oaicompat = false ;
676- bool oaicompat_chat = true ; // TODO: support oaicompat for non-chat
677- std::string oaicompat_model;
678- std::string oaicompat_cmpl_id;
730+ bool verbose = false ;
731+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
732+ std::string oaicompat_model;
733+ std::string oaicompat_cmpl_id;
679734
680735 virtual int get_index () override {
681736 return index;
@@ -686,7 +741,16 @@ struct server_task_result_cmpl_partial : server_task_result {
686741 }
687742
688743 virtual json to_json () override {
689- return oaicompat ? to_json_oaicompat () : to_json_non_oaicompat ();
744+ switch (oaicompat) {
745+ case OAICOMPAT_TYPE_NONE:
746+ return to_json_non_oaicompat ();
747+ case OAICOMPAT_TYPE_COMPLETION:
748+ return to_json_oaicompat ();
749+ case OAICOMPAT_TYPE_CHAT:
750+ return to_json_oaicompat_chat ();
751+ default :
752+ GGML_ASSERT (false && " Invalid oaicompat_type" );
753+ }
690754 }
691755
692756 json to_json_non_oaicompat () {
@@ -711,6 +775,41 @@ struct server_task_result_cmpl_partial : server_task_result {
711775 }
712776
713777 json to_json_oaicompat () {
778+ std::time_t t = std::time (0 );
779+ json logprobs = json (nullptr ); // OAI default to null
780+ if (prob_output.probs .size () > 0 ) {
781+ logprobs = json{
782+ {" content" , completion_token_output::probs_vector_to_json ({prob_output}, post_sampling_probs)},
783+ };
784+ }
785+ json res = json {
786+ {" choices" , json::array ({
787+ json{
788+ {" text" , content},
789+ {" index" , index},
790+ {" logprobs" , logprobs},
791+ {" finish_reason" , nullptr },
792+ }
793+ })},
794+ {" created" , t},
795+ {" model" , oaicompat_model},
796+ {" system_fingerprint" , build_info},
797+ {" object" , " text_completion" },
798+ {" id" , oaicompat_cmpl_id}
799+ };
800+
801+ // extra fields for debugging purposes
802+ if (verbose) {
803+ res[" __verbose" ] = to_json_non_oaicompat ();
804+ }
805+ if (timings.prompt_n >= 0 ) {
806+ res.push_back ({" timings" , timings.to_json ()});
807+ }
808+
809+ return res;
810+ }
811+
812+ json to_json_oaicompat_chat () {
714813 bool first = n_decoded == 0 ;
715814 std::time_t t = std::time (0 );
716815 json choices;
@@ -789,14 +888,16 @@ struct server_task_result_embd : server_task_result {
789888 int32_t n_tokens;
790889
791890 // OAI-compat fields
792- bool oaicompat = false ;
891+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE ;
793892
794893 virtual int get_index () override {
795894 return index;
796895 }
797896
798897 virtual json to_json () override {
799- return oaicompat ? to_json_oaicompat () : to_json_non_oaicompat ();
898+ return oaicompat == OAICOMPAT_TYPE_EMBEDDING
899+ ? to_json_oaicompat ()
900+ : to_json_non_oaicompat ();
800901 }
801902
802903 json to_json_non_oaicompat () {
@@ -2042,7 +2143,6 @@ struct server_context {
20422143
20432144 res->verbose = slot.params .verbose ;
20442145 res->oaicompat = slot.params .oaicompat ;
2045- res->oaicompat_chat = slot.params .oaicompat_chat ;
20462146 res->oaicompat_model = slot.params .oaicompat_model ;
20472147 res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
20482148
@@ -2083,7 +2183,6 @@ struct server_context {
20832183 res->verbose = slot.params .verbose ;
20842184 res->stream = slot.params .stream ;
20852185 res->oaicompat = slot.params .oaicompat ;
2086- res->oaicompat_chat = slot.params .oaicompat_chat ;
20872186 res->oaicompat_model = slot.params .oaicompat_model ;
20882187 res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
20892188
@@ -3504,12 +3603,11 @@ int main(int argc, char ** argv) {
35043603
35053604 // handle completion-like requests (completion, chat, infill)
35063605 // we can optionally provide a custom format for partial results and final results
3507- const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
3606+ const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
35083607 server_task_type type,
35093608 json & data,
35103609 httplib::Response & res,
3511- bool oaicompat = false ,
3512- bool oaicompat_chat = false ) {
3610+ oaicompat_type oaicompat) {
35133611 GGML_ASSERT (type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
35143612
35153613 if (ctx_server.params_base .embedding ) {
@@ -3534,9 +3632,8 @@ int main(int argc, char ** argv) {
35343632 task.id_selected_slot = json_value (data, " id_slot" , -1 );
35353633
35363634 // OAI-compat
3537- task.params .oaicompat = oaicompat;
3538- task.params .oaicompat_chat = oaicompat_chat;
3539- task.params .oaicompat_cmpl_id = completion_id;
3635+ task.params .oaicompat = oaicompat;
3636+ task.params .oaicompat_cmpl_id = completion_id;
35403637 // oaicompat_model is already populated by params_from_json_cmpl
35413638
35423639 tasks.push_back (task);
@@ -3587,7 +3684,7 @@ int main(int argc, char ** argv) {
35873684 }, [&](const json & error_data) {
35883685 server_sent_event (sink, " error" , error_data);
35893686 });
3590- if (oaicompat) {
3687+ if (oaicompat != OAICOMPAT_TYPE_NONE ) {
35913688 static const std::string ev_done = " data: [DONE]\n\n " ;
35923689 sink.write (ev_done.data (), ev_done.size ());
35933690 }
@@ -3603,17 +3700,25 @@ int main(int argc, char ** argv) {
36033700 }
36043701 };
36053702
3606- const auto handle_completions = [&handle_completions_generic ](const httplib::Request & req, httplib::Response & res) {
3703+ const auto handle_completions = [&handle_completions_impl ](const httplib::Request & req, httplib::Response & res) {
36073704 json data = json::parse (req.body );
3608- return handle_completions_generic (
3705+ return handle_completions_impl (
3706+ SERVER_TASK_TYPE_COMPLETION,
3707+ data,
3708+ res,
3709+ OAICOMPAT_TYPE_NONE);
3710+ };
3711+
3712+ const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
3713+ json data = oaicompat_completion_params_parse (json::parse (req.body ));
3714+ return handle_completions_impl (
36093715 SERVER_TASK_TYPE_COMPLETION,
36103716 data,
36113717 res,
3612- /* oaicompat */ false ,
3613- /* oaicompat_chat */ false );
3718+ OAICOMPAT_TYPE_COMPLETION);
36143719 };
36153720
3616- const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic ](const httplib::Request & req, httplib::Response & res) {
3721+ const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl ](const httplib::Request & req, httplib::Response & res) {
36173722 // check model compatibility
36183723 std::string err;
36193724 if (llama_token_fim_pre (ctx_server.model ) == LLAMA_TOKEN_NULL) {
@@ -3682,22 +3787,25 @@ int main(int argc, char ** argv) {
36823787 tokenized_prompts[0 ]
36833788 );
36843789
3685- return handle_completions_generic (SERVER_TASK_TYPE_INFILL, data, res);
3790+ return handle_completions_impl (
3791+ SERVER_TASK_TYPE_INFILL,
3792+ data,
3793+ res,
3794+ OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
36863795 };
36873796
3688- const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic ](const httplib::Request & req, httplib::Response & res) {
3797+ const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl ](const httplib::Request & req, httplib::Response & res) {
36893798 if (ctx_server.params_base .embedding ) {
36903799 res_error (res, format_error_response (" This server does not support completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
36913800 return ;
36923801 }
36933802
3694- json data = oaicompat_completion_params_parse (ctx_server.model , json::parse (req.body ), params.chat_template );
3695- return handle_completions_generic (
3803+ json data = oaicompat_chat_completion_params_parse (ctx_server.model , json::parse (req.body ), params.chat_template );
3804+ return handle_completions_impl (
36963805 SERVER_TASK_TYPE_COMPLETION,
36973806 data,
36983807 res,
3699- /* oaicompat */ true ,
3700- /* oaicompat_chat */ true );
3808+ OAICOMPAT_TYPE_CHAT);
37013809 };
37023810
37033811 const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@@ -3770,10 +3878,10 @@ int main(int argc, char ** argv) {
37703878 res_ok (res, data);
37713879 };
37723880
3773- const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
3881+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
37743882 const json body = json::parse (req.body );
37753883
3776- if (oaicompat && llama_pooling_type (ctx_server.ctx ) == LLAMA_POOLING_TYPE_NONE) {
3884+ if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type (ctx_server.ctx ) == LLAMA_POOLING_TYPE_NONE) {
37773885 res_error (res, format_error_response (" Pooling type 'none' is not OAI compatible. Please use a different pooling type" , ERROR_TYPE_INVALID_REQUEST));
37783886 return ;
37793887 }
@@ -3783,7 +3891,7 @@ int main(int argc, char ** argv) {
37833891 if (body.count (" input" ) != 0 ) {
37843892 prompt = body.at (" input" );
37853893 } else if (body.contains (" content" )) {
3786- oaicompat = false ;
3894+ oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
37873895 prompt = body.at (" content" );
37883896 } else {
37893897 res_error (res, format_error_response (" \" input\" or \" content\" must be provided" , ERROR_TYPE_INVALID_REQUEST));
@@ -3852,16 +3960,18 @@ int main(int argc, char ** argv) {
38523960 }
38533961
38543962 // write JSON response
3855- json root = oaicompat ? format_embeddings_response_oaicompat (body, responses, use_base64) : json (responses);
3963+ json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
3964+ ? format_embeddings_response_oaicompat (body, responses, use_base64)
3965+ : json (responses);
38563966 res_ok (res, root);
38573967 };
38583968
38593969 const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3860- handle_embeddings_impl (req, res, false );
3970+ handle_embeddings_impl (req, res, OAICOMPAT_TYPE_NONE );
38613971 };
38623972
38633973 const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3864- handle_embeddings_impl (req, res, true );
3974+ handle_embeddings_impl (req, res, OAICOMPAT_TYPE_EMBEDDING );
38653975 };
38663976
38673977 const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -4031,7 +4141,7 @@ int main(int argc, char ** argv) {
40314141 svr->Get (" /v1/models" , handle_models); // public endpoint (no API key check)
40324142 svr->Post (" /completion" , handle_completions); // legacy
40334143 svr->Post (" /completions" , handle_completions);
4034- svr->Post (" /v1/completions" , handle_completions );
4144+ svr->Post (" /v1/completions" , handle_completions_oai );
40354145 svr->Post (" /chat/completions" , handle_chat_completions);
40364146 svr->Post (" /v1/chat/completions" , handle_chat_completions);
40374147 svr->Post (" /infill" , handle_infill);
0 commit comments