@@ -128,10 +128,11 @@ struct slot_params {
128128 bool can_speculative;
129129
130130 // OAI-compat fields
131- bool oaicompat = false ;
131+ bool verbose = false ;
132+ bool oaicompat = false ;
133+ bool oaicompat_chat = true ;
132134 std::string oaicompat_model;
133135 std::string oaicompat_cmpl_id;
134- bool verbose = false ;
135136
136137 json to_json () {
137138 std::vector<std::string> samplers;
@@ -226,10 +227,6 @@ struct server_task_result {
226227 return -1 ;
227228 }
228229 virtual json to_json () = 0;
229- virtual json to_json_oai_compat () {
230- // used by server_task_result_cmpl_final and server_task_result_cmpl_partial
231- return json ();
232- }
233230 virtual ~server_task_result () = default ;
234231};
235232
@@ -299,16 +296,21 @@ struct server_task_result_cmpl_final : server_task_result {
299296 slot_params generation_params;
300297
301298 // OAI-compat fields
299+ bool verbose = false ;
300+ bool oaicompat = false ;
301+ bool oaicompat_chat = true ; // TODO: support oaicompat for non-chat
302302 std::string oaicompat_model;
303303 std::string oaicompat_cmpl_id;
304- bool verbose = false ;
305304
306305 virtual int get_index () override {
307306 return index;
308307 }
309308
310309 virtual json to_json () override {
311- // non-OAI-compat JSON
310+ if (oaicompat) {
311+ return to_json_oai_compat ();
312+ }
313+ // otherwise, non-OAI-compat JSON
312314 json res = json {
313315 {" index" , index},
314316 {" content" , content},
@@ -332,7 +334,7 @@ struct server_task_result_cmpl_final : server_task_result {
332334 return res;
333335 }
334336
335- virtual json to_json_oai_compat () override {
337+ json to_json_oai_compat () {
336338 std::string finish_reason = " length" ;
337339 if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
338340 finish_reason = " stop" ;
@@ -388,9 +390,11 @@ struct server_task_result_cmpl_partial : server_task_result {
388390 result_timings timings;
389391
390392 // OAI-compat fields
393+ bool verbose = false ;
394+ bool oaicompat = false ;
395+ bool oaicompat_chat = true ; // TODO: support oaicompat for non-chat
391396 std::string oaicompat_model;
392397 std::string oaicompat_cmpl_id;
393- bool verbose = false ;
394398
395399 virtual int get_index () override {
396400 return index;
@@ -401,6 +405,9 @@ struct server_task_result_cmpl_partial : server_task_result {
401405 }
402406
403407 virtual json to_json () override {
408+ if (oaicompat) {
409+ return to_json_oai_compat ();
410+ }
404411 bool is_stop = stop != STOP_TYPE_NONE;
405412 // non-OAI-compat JSON
406413 json res = json {
@@ -425,7 +432,7 @@ struct server_task_result_cmpl_partial : server_task_result {
425432 return res;
426433 }
427434
428- virtual json to_json_oai_compat () override {
435+ json to_json_oai_compat () {
429436 bool first = n_decoded == 0 ;
430437
431438 std::string finish_reason;
@@ -1461,6 +1468,7 @@ struct server_context {
14611468 if (data.count (" __oaicompat" ) != 0 ) {
14621469 std::string model_name = params_base.model_alias .empty () ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias ;
14631470 slot.params .oaicompat = true ;
1471+ slot.params .oaicompat_chat = json_value (data, " __oaicompat_chat" , false );
14641472 slot.params .oaicompat_model = json_value (data, " model" , model_name);
14651473 slot.params .oaicompat_cmpl_id = json_value (data, " completion_id" , std::string ());
14661474 } else {
@@ -1850,9 +1858,11 @@ struct server_context {
18501858
18511859 res->stop = slot.stop ;
18521860
1861+ res->verbose = slot.params .verbose ;
1862+ res->oaicompat = slot.params .oaicompat ;
1863+ res->oaicompat_chat = slot.params .oaicompat_chat ;
18531864 res->oaicompat_model = slot.params .oaicompat_model ;
18541865 res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
1855- res->verbose = slot.params .verbose ;
18561866
18571867 // populate res.probs_output
18581868 if (slot.params .sampling .n_probs > 0 ) {
@@ -1899,9 +1909,11 @@ struct server_context {
18991909 res->stopping_word = slot.stopping_word ;
19001910 res->stop = slot.stop ;
19011911
1912+ res->verbose = slot.params .verbose ;
1913+ res->oaicompat = slot.params .oaicompat ;
1914+ res->oaicompat_chat = slot.params .oaicompat_chat ;
19021915 res->oaicompat_model = slot.params .oaicompat_model ;
19031916 res->oaicompat_cmpl_id = slot.params .oaicompat_cmpl_id ;
1904- res->verbose = slot.params .verbose ;
19051917
19061918 // populate res.probs_output
19071919 if (slot.params .sampling .n_probs > 0 ) {
@@ -3397,12 +3409,12 @@ int main(int argc, char ** argv) {
33973409 ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
33983410 if (results.size () == 1 ) {
33993411 // single result
3400- res_ok (res, oai_compat ? results[ 0 ]-> to_json_oai_compat () : results[0 ]->to_json ());
3412+ res_ok (res, results[0 ]->to_json ());
34013413 } else {
34023414 // multiple results (multitask)
34033415 json arr = json::array ();
34043416 for (auto & res : results) {
3405- arr.push_back (oai_compat ? res-> to_json_oai_compat () : res->to_json ());
3417+ arr.push_back (res->to_json ());
34063418 }
34073419 res_ok (res, arr);
34083420 }
@@ -3414,7 +3426,7 @@ int main(int argc, char ** argv) {
34143426 } else {
34153427 const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t , httplib::DataSink & sink) {
34163428 ctx_server.receive_cmpl_results_stream (task_ids, [&](server_task_result_ptr & result) -> bool {
3417- json res_json = oai_compat ? result-> to_json_oai_compat () : result->to_json ();
3429+ json res_json = result->to_json ();
34183430 if (res_json.is_array ()) {
34193431 for (const auto & res : res_json) {
34203432 if (!server_sent_event (sink, " data" , res)) {
@@ -3506,7 +3518,7 @@ int main(int argc, char ** argv) {
35063518 }
35073519
35083520 json data = oaicompat_completion_params_parse (ctx_server.model , json::parse (req.body ), params.chat_template );
3509-
3521+ data[ " __oaicompat_chat " ] = true ;
35103522 return handle_completions_generic (SERVER_TASK_INF_TYPE_COMPLETION, data, res, true );
35113523 };
35123524
0 commit comments