diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 9f0b0ffaa6e1e..776f466719dac 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -77,6 +77,7 @@ enum oaicompat_type {
     OAICOMPAT_TYPE_CHAT,
     OAICOMPAT_TYPE_COMPLETION,
     OAICOMPAT_TYPE_EMBEDDING,
+    OAICOMPAT_TYPE_API_CHAT
 };
 
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -676,6 +677,8 @@ struct server_task_result_cmpl_final : server_task_result {
                 return to_json_oaicompat();
             case OAICOMPAT_TYPE_CHAT:
                 return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+            case OAICOMPAT_TYPE_API_CHAT:
+                return to_json_oaicompat_api_chat();
             default:
                 GGML_ASSERT(false && "Invalid oaicompat_type");
         }
@@ -858,6 +861,55 @@ struct server_task_result_cmpl_final : server_task_result {
 
         return deltas;
     }
+
+    json to_json_oaicompat_api_chat() {
+        // Ollama final response format (streaming or non-streaming)
+        std::time_t t = std::time(0);
+        std::string finish_reason = "none"; // default value
+        if (stop == STOP_TYPE_EOS || stop == STOP_TYPE_WORD) {
+            // Ollama uses "stop" for both EOS and word stops
+            finish_reason = "stop";
+        } else if (stop == STOP_TYPE_LIMIT) {
+            // Ollama uses "length" for limit stops
+            finish_reason = "length";
+        }
+
+        uint64_t prompt_ns    = static_cast<uint64_t>(timings.prompt_ms)    * 1e6; // ms to ns
+        uint64_t predicted_ns = static_cast<uint64_t>(timings.predicted_ms) * 1e6; // ms to ns
+
+        json res = {
+            { "model",      oaicompat_model },
+            { "created_at", t },
+            { "message",
+              {
+                  { "role",    "assistant" },
+                  { "content", stream ? "" : content } // content is empty in final streaming chunk
+              } },
+            { "done_reason", finish_reason },
+            { "done",        true },
+            // Add metrics from timings and other fields, converted to nanoseconds
+            { "total_duration",       prompt_ns + predicted_ns },
+            { "load_duration",        prompt_ns }, // Assuming load duration is prompt eval time
+            { "prompt_eval_count",    n_prompt_tokens },
+            { "prompt_eval_duration", prompt_ns },
+            { "eval_count",           n_decoded },
+            { "eval_duration",        predicted_ns },
+            { "prompt_tokens",        n_prompt_tokens },
+            { "completion_tokens",    n_decoded },
+            { "total_tokens",         n_prompt_tokens + n_decoded },
+            { "id_slot",              id_slot },
+            { "id",                   oaicompat_cmpl_id },
+            { "system_fingerprint",   build_info },
+            { "object",               "chat.completion" },
+        };
+
+        // Ollama non-streaming includes the full content in the final response
+        if (!stream) {
+            res["message"]["content"] = content;
+        }
+
+        return res;
+    }
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
@@ -896,6 +948,8 @@ struct server_task_result_cmpl_partial : server_task_result {
                 return to_json_oaicompat();
             case OAICOMPAT_TYPE_CHAT:
                 return to_json_oaicompat_chat();
+            case OAICOMPAT_TYPE_API_CHAT:
+                return to_json_oaicompat_api_chat();
             default:
                 GGML_ASSERT(false && "Invalid oaicompat_type");
         }
@@ -1007,6 +1061,24 @@ struct server_task_result_cmpl_partial : server_task_result {
 
         return deltas;
     }
+
+    json to_json_oaicompat_api_chat() {
+        std::time_t t = std::time(0);
+        {
+            // Ollama streaming partial response format
+            json res = {
+                { "model",      oaicompat_model },
+                { "created_at", t },
+                { "message",
+                  {
+                      { "role", "assistant" }, { "content", content } // partial content
+                  } },
+                { "done", false }
+            };
+            // Ollama streaming responses don't seem to include timings or logprobs per partial token
+            return res;
+        }
+    }
 };
 
 struct server_task_result_embd : server_task_result {
@@ -4294,14 +4366,27 @@ int main(int argc, char ** argv) {
                     json res_json = result->to_json();
                     if (res_json.is_array()) {
                         for (const auto & res : res_json) {
-                            if (!server_sent_event(sink, "data", res)) {
-                                // sending failed (HTTP connection closed), cancel the generation
-                                return false;
+                            // ollama's /api/chat does not conform to the SSE format
+                            if (oaicompat == OAICOMPAT_TYPE_API_CHAT) {
+                                std::string s = safe_json_to_str(res) + "\n";
+                                if (!sink.write(s.data(), s.size())) {
+                                    return false;
+                                }
+                            } else {
+                                if (!server_sent_event(sink, "data", res)) {
+                                    // sending failed (HTTP connection closed), cancel the generation
+                                    return false;
+                                }
                             }
                         }
                         return true;
                     } else {
-                        return server_sent_event(sink, "data", res_json);
+                        if (oaicompat == OAICOMPAT_TYPE_API_CHAT) {
+                            std::string s = safe_json_to_str(res_json) + "\n";
+                            return sink.write(s.data(), s.size());
+                        } else {
+                            return server_sent_event(sink, "data", res_json);
+                        }
                     }
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
@@ -4309,7 +4394,7 @@ int main(int argc, char ** argv) {
                     // note: do not use req.is_connection_closed here because req is already destroyed
                     return !sink.is_writable();
                 });
-                if (oaicompat != OAICOMPAT_TYPE_NONE) {
+                if (oaicompat != OAICOMPAT_TYPE_NONE && oaicompat != OAICOMPAT_TYPE_API_CHAT) {
                     static const std::string ev_done = "data: [DONE]\n\n";
                     sink.write(ev_done.data(), ev_done.size());
                 }
@@ -4436,6 +4521,16 @@ int main(int argc, char ** argv) {
         }
 
         auto body = json::parse(req.body);
+        // Ollama chat endpoint specific handling
+        auto OAICOMPAT_TYPE = OAICOMPAT_TYPE_CHAT;
+        if (req.path == "/api/chat") {
+            OAICOMPAT_TYPE = OAICOMPAT_TYPE_API_CHAT;
+
+            // Set default stream to true for /api/chat
+            if (!body.contains("stream")) {
+                body["stream"] = true;
+            }
+        }
         std::vector<raw_buffer> files;
         json data = oaicompat_chat_params_parse(
             body,
@@ -4448,7 +4543,7 @@ int main(int argc, char ** argv) {
             files,
             req.is_connection_closed,
             res,
-            OAICOMPAT_TYPE_CHAT);
+            OAICOMPAT_TYPE);
     };
 
     // same with handle_chat_completions, but without inference part
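
For reference, a minimal sketch of the wire format the patched endpoint produces. All values below are illustrative, not taken from a real run; the model name "llama3" and the id/fingerprint strings are placeholders. With the patch applied, a streaming request such as

    curl -N http://localhost:8080/api/chat -d '{"model":"llama3","messages":[{"role":"user","content":"Hi"}]}'

(streaming is the default for /api/chat per the handler above) yields newline-delimited JSON objects instead of SSE "data:" events, with no "data: [DONE]" terminator:

    {"model":"llama3","created_at":1700000000,"message":{"role":"assistant","content":"Hel"},"done":false}
    {"model":"llama3","created_at":1700000000,"message":{"role":"assistant","content":"lo"},"done":false}
    {"model":"llama3","created_at":1700000001,"message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":450000000,"load_duration":120000000,"prompt_eval_count":12,"prompt_eval_duration":120000000,"eval_count":2,"eval_duration":330000000,"prompt_tokens":12,"completion_tokens":2,"total_tokens":14,"id_slot":0,"id":"chatcmpl-abc123","system_fingerprint":"b1234-abcdef0","object":"chat.completion"}

With "stream": false, a single object of the final form is returned with the full text in message.content. Note that created_at here is the integer Unix timestamp produced by std::time(0), whereas Ollama's own API emits an RFC 3339 string, so strict Ollama clients may need the value reformatted.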