@@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
                 res.set_content(models.dump(), "application/json; charset=utf-8");
             });
 
+    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+    {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        if (!validate_api_key(req, res)) {
+            return;
+        }
+        json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
 
-    // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                if (!validate_api_key(req, res)) {
-                    return;
-                }
-                json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
-
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, data, false, false, -1);
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+        llama.request_completion(task_id, data, false, false, -1);
 
-                if (!json_value(data, "stream", false)) {
-                    std::string completion_text;
-                    task_result result = llama.queue_results.recv(task_id);
+        if (!json_value(data, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.queue_results.recv(task_id);
 
-                    if (!result.error && result.stop) {
-                        json oaicompat_result = format_final_response_oaicompat(data, result);
+            if (!result.error && result.stop) {
+                json oaicompat_result = format_final_response_oaicompat(data, result);
 
-                        res.set_content(oaicompat_result.dump(-1, ' ', false,
-                                            json::error_handler_t::replace),
-                                            "application/json; charset=utf-8");
-                    } else {
-                        res.status = 500;
-                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    }
-                    llama.queue_results.remove_waiting_task_id(task_id);
-                } else {
-                    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
-                        while (true) {
-                            task_result llama_result = llama.queue_results.recv(task_id);
-                            if (!llama_result.error) {
-                                std::vector<json> result_array = format_partial_response_oaicompat(llama_result);
+                res.set_content(oaicompat_result.dump(-1, ' ', false,
+                                    json::error_handler_t::replace),
+                                    "application/json; charset=utf-8");
+            } else {
+                res.status = 500;
+                res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
+            }
+            llama.queue_results.remove_waiting_task_id(task_id);
+        } else {
+            const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+                while (true) {
+                    task_result llama_result = llama.queue_results.recv(task_id);
+                    if (!llama_result.error) {
+                        std::vector<json> result_array = format_partial_response_oaicompat(llama_result);
 
-                                for (auto it = result_array.begin(); it != result_array.end(); ++it)
-                                {
-                                    if (!it->empty()) {
-                                        const std::string str =
-                                            "data: " +
-                                            it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                            "\n\n";
-                                        LOG_VERBOSE("data stream", {{"to_send", str}});
-                                        if (!sink.write(str.c_str(), str.size())) {
-                                            llama.queue_results.remove_waiting_task_id(task_id);
-                                            return false;
-                                        }
-                                    }
-                                }
-                                if (llama_result.stop) {
-                                    break;
-                                }
-                            } else {
+                        for (auto it = result_array.begin(); it != result_array.end(); ++it)
+                        {
+                            if (!it->empty()) {
                                 const std::string str =
-                                    "error: " +
-                                    llama_result.result_json.dump(-1, ' ', false,
-                                        json::error_handler_t::replace) +
+                                    "data: " +
+                                    it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                     "\n\n";
                                 LOG_VERBOSE("data stream", {{"to_send", str}});
                                 if (!sink.write(str.c_str(), str.size())) {
                                     llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
-                                break;
                             }
                         }
-                        sink.done();
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return true;
-                    };
+                        if (llama_result.stop) {
+                            break;
+                        }
+                    } else {
+                        const std::string str =
+                            "error: " +
+                            llama_result.result_json.dump(-1, ' ', false,
+                                json::error_handler_t::replace) +
+                            "\n\n";
+                        LOG_VERBOSE("data stream", {{"to_send", str}});
+                        if (!sink.write(str.c_str(), str.size())) {
+                            llama.queue_results.remove_waiting_task_id(task_id);
+                            return false;
+                        }
+                        break;
+                    }
+                }
+                sink.done();
+                llama.queue_results.remove_waiting_task_id(task_id);
+                return true;
+            };
 
-                    auto on_complete = [task_id, &llama](bool) {
-                        // cancel request
-                        llama.request_cancel(task_id);
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                    };
+            auto on_complete = [task_id, &llama](bool) {
+                // cancel request
+                llama.request_cancel(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+            };
 
-                    res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-                }
-            });
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+        }
+    };
+
+    svr.Post("/chat/completions", chat_completions);
+    svr.Post("/v1/chat/completions", chat_completions);
 
     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
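The net effect of this hunk: the OpenAI-compatible handler is factored into a reusable chat_completions lambda and registered for both /chat/completions and /v1/chat/completions, resolving the old TODO about mounting the endpoint without the "/v1" prefix. Below is a minimal client-side sketch that exercises both routes with cpp-httplib (the same header-only HTTP library the server is built on); the host, port, and request body are illustrative assumptions, not part of the change:

#include <iostream>
#include <string>
#include "httplib.h"

int main() {
    // Assumed server address; adjust to wherever the llama.cpp server is listening.
    httplib::Client cli("localhost", 8080);

    // Minimal OpenAI-style chat request; non-streaming to keep the example short.
    const std::string body = R"({"messages":[{"role":"user","content":"Hello"}],"stream":false})";

    // After this change, both paths are served by the same chat_completions handler.
    for (const char *path : {"/chat/completions", "/v1/chat/completions"}) {
        auto res = cli.Post(path, body, "application/json");
        if (res) {
            std::cout << path << " -> HTTP " << res->status << "\n" << res->body << "\n";
        } else {
            std::cout << path << " -> request failed\n";
        }
    }
    return 0;
}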