Skip to content

Commit 90889fd

Browse files
committed
server : add OAI compat for /v1/completions
1 parent 9ba399d commit 90889fd

File tree

2 files changed

+198
-52
lines changed

2 files changed

+198
-52
lines changed

examples/server/server.cpp

Lines changed: 158 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ enum server_task_type {
6767
SERVER_TASK_TYPE_SET_LORA,
6868
};
6969

70+
// Which OpenAI-compatible response format a request/result uses.
// NONE selects the server's native (non-OAI) JSON format.
enum oaicompat_type {
    OAICOMPAT_TYPE_NONE,       // native llama.cpp server format
    OAICOMPAT_TYPE_CHAT,       // /v1/chat/completions
    OAICOMPAT_TYPE_COMPLETION, // /v1/completions
    OAICOMPAT_TYPE_EMBEDDING,  // /v1/embeddings
};
76+
7077
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
7178
enum error_type {
7279
ERROR_TYPE_INVALID_REQUEST,
@@ -101,11 +108,10 @@ struct slot_params {
101108
struct common_params_speculative speculative;
102109

103110
// OAI-compat fields
104-
bool verbose = false;
105-
bool oaicompat = false;
106-
bool oaicompat_chat = true;
107-
std::string oaicompat_model;
108-
std::string oaicompat_cmpl_id;
111+
bool verbose = false;
112+
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
113+
std::string oaicompat_model;
114+
std::string oaicompat_cmpl_id;
109115

110116
json to_json() const {
111117
std::vector<std::string> samplers;
@@ -529,11 +535,10 @@ struct server_task_result_cmpl_final : server_task_result {
529535
slot_params generation_params;
530536

531537
// OAI-compat fields
532-
bool verbose = false;
533-
bool oaicompat = false;
534-
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
535-
std::string oaicompat_model;
536-
std::string oaicompat_cmpl_id;
538+
bool verbose = false;
539+
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
540+
std::string oaicompat_model;
541+
std::string oaicompat_cmpl_id;
537542

538543
virtual int get_index() override {
539544
return index;
@@ -544,9 +549,16 @@ struct server_task_result_cmpl_final : server_task_result {
544549
}
545550

546551
// Dispatch the final completion result to the formatter matching the
// requested OAI-compatibility mode.
virtual json to_json() override {
    if (oaicompat == OAICOMPAT_TYPE_NONE) {
        return to_json_non_oaicompat();
    }
    if (oaicompat == OAICOMPAT_TYPE_COMPLETION) {
        return to_json_oaicompat();
    }
    if (oaicompat == OAICOMPAT_TYPE_CHAT) {
        // streamed chat responses use a dedicated final-chunk shape
        return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
    }
    GGML_ASSERT(false && "Invalid oaicompat_type");
}
551563

552564
json to_json_non_oaicompat() {
@@ -574,6 +586,50 @@ struct server_task_result_cmpl_final : server_task_result {
574586
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
575587
}
576588

589+
// Build the final response in OAI "/v1/completions" (text_completion) format.
json to_json_oaicompat() {
    std::time_t created = std::time(0);

    // OAI defaults "logprobs" to null; when streaming, per-token probs were
    // already delivered with the partial chunks, so only fill it otherwise
    json logprobs = json(nullptr);
    if (!stream && probs_output.size() > 0) {
        logprobs = json{
            {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
        };
    }

    // "stop" when generation hit a stop word or EOS, otherwise "length"
    const bool stopped = stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS;
    json finish_reason = stopped ? "stop" : "length";

    json res = json {
        {"choices", json::array({
            json{
                {"text", stream ? "" : content}, // in stream mode, content was already sent in the last partial chunk
                {"index", index},
                {"logprobs", logprobs},
                {"finish_reason", finish_reason},
            }
        })},
        {"created", created},
        {"model", oaicompat_model},
        {"system_fingerprint", build_info},
        {"object", "text_completion"},
        {"usage", json {
            {"completion_tokens", n_decoded},
            {"prompt_tokens", n_prompt_tokens},
            {"total_tokens", n_decoded + n_prompt_tokens}
        }},
        {"id", oaicompat_cmpl_id}
    };

    // extra fields for debugging purposes
    if (verbose) {
        res["__verbose"] = to_json_non_oaicompat();
    }
    if (timings.prompt_n >= 0) {
        res.push_back({"timings", timings.to_json()});
    }

    return res;
}
632+
577633
json to_json_oaicompat_chat() {
578634
std::string finish_reason = "length";
579635
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
@@ -671,11 +727,10 @@ struct server_task_result_cmpl_partial : server_task_result {
671727
result_timings timings;
672728

673729
// OAI-compat fields
674-
bool verbose = false;
675-
bool oaicompat = false;
676-
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
677-
std::string oaicompat_model;
678-
std::string oaicompat_cmpl_id;
730+
bool verbose = false;
731+
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
732+
std::string oaicompat_model;
733+
std::string oaicompat_cmpl_id;
679734

680735
virtual int get_index() override {
681736
return index;
@@ -686,7 +741,16 @@ struct server_task_result_cmpl_partial : server_task_result {
686741
}
687742

688743
// Dispatch a partial (streamed) result to the formatter matching the
// requested OAI-compatibility mode.
virtual json to_json() override {
    if (oaicompat == OAICOMPAT_TYPE_NONE) {
        return to_json_non_oaicompat();
    }
    if (oaicompat == OAICOMPAT_TYPE_COMPLETION) {
        return to_json_oaicompat();
    }
    if (oaicompat == OAICOMPAT_TYPE_CHAT) {
        return to_json_oaicompat_chat();
    }
    GGML_ASSERT(false && "Invalid oaicompat_type");
}
691755

692756
json to_json_non_oaicompat() {
@@ -711,6 +775,41 @@ struct server_task_result_cmpl_partial : server_task_result {
711775
}
712776

713777
// Build one streamed chunk in OAI "/v1/completions" (text_completion) format.
json to_json_oaicompat() {
    std::time_t created = std::time(0);

    // OAI defaults "logprobs" to null; attach this token's probs when present
    json logprobs = json(nullptr);
    if (prob_output.probs.size() > 0) {
        logprobs = json{
            {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
        };
    }

    json res = json {
        {"choices", json::array({
            json{
                {"text", content},
                {"index", index},
                {"logprobs", logprobs},
                {"finish_reason", nullptr}, // finish_reason is only set on the final chunk
            }
        })},
        {"created", created},
        {"model", oaicompat_model},
        {"system_fingerprint", build_info},
        {"object", "text_completion"},
        {"id", oaicompat_cmpl_id}
    };

    // extra fields for debugging purposes
    if (verbose) {
        res["__verbose"] = to_json_non_oaicompat();
    }
    if (timings.prompt_n >= 0) {
        res.push_back({"timings", timings.to_json()});
    }

    return res;
}
811+
812+
json to_json_oaicompat_chat() {
714813
bool first = n_decoded == 0;
715814
std::time_t t = std::time(0);
716815
json choices;
@@ -789,14 +888,16 @@ struct server_task_result_embd : server_task_result {
789888
int32_t n_tokens;
790889

791890
// OAI-compat fields
792-
bool oaicompat = false;
891+
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
793892

794893
virtual int get_index() override {
795894
return index;
796895
}
797896

798897
// Embedding results only have an OAI representation for /v1/embeddings;
// every other mode falls back to the native format.
virtual json to_json() override {
    if (oaicompat == OAICOMPAT_TYPE_EMBEDDING) {
        return to_json_oaicompat();
    }
    return to_json_non_oaicompat();
}
801902

802903
json to_json_non_oaicompat() {
@@ -2042,7 +2143,6 @@ struct server_context {
20422143

20432144
res->verbose = slot.params.verbose;
20442145
res->oaicompat = slot.params.oaicompat;
2045-
res->oaicompat_chat = slot.params.oaicompat_chat;
20462146
res->oaicompat_model = slot.params.oaicompat_model;
20472147
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
20482148

@@ -2083,7 +2183,6 @@ struct server_context {
20832183
res->verbose = slot.params.verbose;
20842184
res->stream = slot.params.stream;
20852185
res->oaicompat = slot.params.oaicompat;
2086-
res->oaicompat_chat = slot.params.oaicompat_chat;
20872186
res->oaicompat_model = slot.params.oaicompat_model;
20882187
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
20892188

@@ -3504,12 +3603,11 @@ int main(int argc, char ** argv) {
35043603

35053604
// handle completion-like requests (completion, chat, infill)
35063605
// we can optionally provide a custom format for partial results and final results
3507-
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
3606+
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
35083607
server_task_type type,
35093608
json & data,
35103609
httplib::Response & res,
3511-
bool oaicompat = false,
3512-
bool oaicompat_chat = false) {
3610+
oaicompat_type oaicompat) {
35133611
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
35143612

35153613
if (ctx_server.params_base.embedding) {
@@ -3534,9 +3632,8 @@ int main(int argc, char ** argv) {
35343632
task.id_selected_slot = json_value(data, "id_slot", -1);
35353633

35363634
// OAI-compat
3537-
task.params.oaicompat = oaicompat;
3538-
task.params.oaicompat_chat = oaicompat_chat;
3539-
task.params.oaicompat_cmpl_id = completion_id;
3635+
task.params.oaicompat = oaicompat;
3636+
task.params.oaicompat_cmpl_id = completion_id;
35403637
// oaicompat_model is already populated by params_from_json_cmpl
35413638

35423639
tasks.push_back(task);
@@ -3587,7 +3684,7 @@ int main(int argc, char ** argv) {
35873684
}, [&](const json & error_data) {
35883685
server_sent_event(sink, "error", error_data);
35893686
});
3590-
if (oaicompat) {
3687+
if (oaicompat != OAICOMPAT_TYPE_NONE) {
35913688
static const std::string ev_done = "data: [DONE]\n\n";
35923689
sink.write(ev_done.data(), ev_done.size());
35933690
}
@@ -3603,17 +3700,25 @@ int main(int argc, char ** argv) {
36033700
}
36043701
};
36053702

3606-
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3703+
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
36073704
json data = json::parse(req.body);
3608-
return handle_completions_generic(
3705+
return handle_completions_impl(
3706+
SERVER_TASK_TYPE_COMPLETION,
3707+
data,
3708+
res,
3709+
OAICOMPAT_TYPE_NONE);
3710+
};
3711+
3712+
// POST /v1/completions: OAI-compatible text-completion endpoint;
// the request body is translated to native params before dispatch
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
    json data = oaicompat_completion_params_parse(json::parse(req.body));
    return handle_completions_impl(SERVER_TASK_TYPE_COMPLETION, data, res, OAICOMPAT_TYPE_COMPLETION);
};
36153720

3616-
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3721+
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
36173722
// check model compatibility
36183723
std::string err;
36193724
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
@@ -3682,22 +3787,25 @@ int main(int argc, char ** argv) {
36823787
tokenized_prompts[0]
36833788
);
36843789

3685-
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
3790+
return handle_completions_impl(
3791+
SERVER_TASK_TYPE_INFILL,
3792+
data,
3793+
res,
3794+
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
36863795
};
36873796

3688-
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3797+
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
36893798
if (ctx_server.params_base.embedding) {
36903799
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
36913800
return;
36923801
}
36933802

3694-
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3695-
return handle_completions_generic(
3803+
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3804+
return handle_completions_impl(
36963805
SERVER_TASK_TYPE_COMPLETION,
36973806
data,
36983807
res,
3699-
/* oaicompat */ true,
3700-
/* oaicompat_chat */ true);
3808+
OAICOMPAT_TYPE_CHAT);
37013809
};
37023810

37033811
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@@ -3770,10 +3878,10 @@ int main(int argc, char ** argv) {
37703878
res_ok(res, data);
37713879
};
37723880

3773-
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
3881+
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
37743882
const json body = json::parse(req.body);
37753883

3776-
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
3884+
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
37773885
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
37783886
return;
37793887
}
@@ -3783,7 +3891,7 @@ int main(int argc, char ** argv) {
37833891
if (body.count("input") != 0) {
37843892
prompt = body.at("input");
37853893
} else if (body.contains("content")) {
3786-
oaicompat = false;
3894+
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
37873895
prompt = body.at("content");
37883896
} else {
37893897
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@@ -3852,16 +3960,18 @@ int main(int argc, char ** argv) {
38523960
}
38533961

38543962
// write JSON response
3855-
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
3963+
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
3964+
? format_embeddings_response_oaicompat(body, responses, use_base64)
3965+
: json(responses);
38563966
res_ok(res, root);
38573967
};
38583968

38593969
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3860-
handle_embeddings_impl(req, res, false);
3970+
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
38613971
};
38623972

38633973
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3864-
handle_embeddings_impl(req, res, true);
3974+
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
38653975
};
38663976

38673977
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -4031,7 +4141,7 @@ int main(int argc, char ** argv) {
40314141
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
40324142
svr->Post("/completion", handle_completions); // legacy
40334143
svr->Post("/completions", handle_completions);
4034-
svr->Post("/v1/completions", handle_completions);
4144+
svr->Post("/v1/completions", handle_completions_oai);
40354145
svr->Post("/chat/completions", handle_chat_completions);
40364146
svr->Post("/v1/chat/completions", handle_chat_completions);
40374147
svr->Post("/infill", handle_infill);

0 commit comments

Comments
 (0)