Skip to content

Commit 24d263d

Browse files
authored
Add proper implementation of ollama's /api/chat
1 parent de2ef53 commit 24d263d

File tree

1 file changed

+101
-6
lines changed

1 file changed

+101
-6
lines changed

tools/server/server.cpp

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ enum oaicompat_type {
7777
OAICOMPAT_TYPE_CHAT,
7878
OAICOMPAT_TYPE_COMPLETION,
7979
OAICOMPAT_TYPE_EMBEDDING,
80+
OAICOMPAT_TYPE_API_CHAT
8081
};
8182

8283
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -676,6 +677,8 @@ struct server_task_result_cmpl_final : server_task_result {
676677
return to_json_oaicompat();
677678
case OAICOMPAT_TYPE_CHAT:
678679
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
680+
case OAICOMPAT_TYPE_API_CHAT:
681+
return to_json_oaicompat_api_chat();
679682
default:
680683
GGML_ASSERT(false && "Invalid oaicompat_type");
681684
}
@@ -858,6 +861,55 @@ struct server_task_result_cmpl_final : server_task_result {
858861

859862
return deltas;
860863
}
864+
865+
json to_json_oaicompat_api_chat() {
866+
// Ollama final response format (streaming or non-streaming)
867+
std::time_t t = std::time(0);
868+
std::string finish_reason = "none"; // default value
869+
if (stop == STOP_TYPE_EOS || stop == STOP_TYPE_WORD) {
870+
// Ollama uses "stop" for both EOS and word stops
871+
finish_reason = "stop";
872+
} else if (stop == STOP_TYPE_LIMIT) {
873+
// Ollama uses "length" for limit stops
874+
finish_reason = "length";
875+
}
876+
877+
uint64_t prompt_ns = static_cast<uint64_t>(timings.prompt_ms) * 1e6; // ms to ns
878+
uint64_t predicted_ns = static_cast<uint64_t>(timings.predicted_ms) * 1e6; // ms to ns
879+
880+
json res = {
881+
{ "model", oaicompat_model },
882+
{ "created_at", t },
883+
{ "message",
884+
{
885+
{ "role", "assistant" },
886+
{ "content", stream ? "" : content } // content is empty in final streaming chunk
887+
} },
888+
{ "done_reason", finish_reason },
889+
{ "done", true },
890+
// Add metrics from timings and other fields, converted to nanoseconds
891+
{ "total_duration", prompt_ns + predicted_ns },
892+
{ "load_duration", prompt_ns }, // Assuming load duration is prompt eval time
893+
{ "prompt_eval_count", n_prompt_tokens },
894+
{ "prompt_eval_duration", prompt_ns },
895+
{ "eval_count", n_decoded },
896+
{ "eval_duration", predicted_ns },
897+
{ "prompt_tokens", n_prompt_tokens },
898+
{ "completion_tokens", n_decoded },
899+
{ "total_tokens", n_prompt_tokens + n_decoded },
900+
{ "id_slot", id_slot },
901+
{ "id", oaicompat_cmpl_id },
902+
{ "system_fingerprint", build_info },
903+
{ "object", "chat.completion" },
904+
};
905+
906+
// Ollama non-streaming includes the full content in the final response
907+
if (!stream) {
908+
res["message"]["content"] = content;
909+
}
910+
911+
return res;
912+
}
861913
};
862914

863915
struct server_task_result_cmpl_partial : server_task_result {
@@ -896,6 +948,8 @@ struct server_task_result_cmpl_partial : server_task_result {
896948
return to_json_oaicompat();
897949
case OAICOMPAT_TYPE_CHAT:
898950
return to_json_oaicompat_chat();
951+
case OAICOMPAT_TYPE_API_CHAT:
952+
return to_json_oaicompat_api_chat();
899953
default:
900954
GGML_ASSERT(false && "Invalid oaicompat_type");
901955
}
@@ -1007,6 +1061,24 @@ struct server_task_result_cmpl_partial : server_task_result {
10071061

10081062
return deltas;
10091063
}
1064+
1065+
json to_json_oaicompat_api_chat() {
1066+
std::time_t t = std::time(0);
1067+
{
1068+
// Ollama streaming partial response format
1069+
json res = {
1070+
{ "model", oaicompat_model },
1071+
{ "created_at", t },
1072+
{ "message",
1073+
{
1074+
{ "role", "assistant" }, { "content", content } // partial content
1075+
} },
1076+
{ "done", false }
1077+
};
1078+
// Ollama streaming responses don't seem to include timings or logprobs per partial token
1079+
return res;
1080+
}
1081+
}
10101082
};
10111083

10121084
struct server_task_result_embd : server_task_result {
@@ -4294,22 +4366,35 @@ int main(int argc, char ** argv) {
42944366
json res_json = result->to_json();
42954367
if (res_json.is_array()) {
42964368
for (const auto & res : res_json) {
4297-
if (!server_sent_event(sink, "data", res)) {
4298-
// sending failed (HTTP connection closed), cancel the generation
4299-
return false;
4369+
// ollama's /api/chat does not conform to the SSE (server-sent events) format
4370+
if (oaicompat == OAICOMPAT_TYPE_API_CHAT) {
4371+
std::string s = safe_json_to_str(res) + "\n";
4372+
if (!sink.write(s.data(), s.size())) {
4373+
return false;
4374+
}
4375+
} else {
4376+
if (!server_sent_event(sink, "data", res)) {
4377+
// sending failed (HTTP connection closed), cancel the generation
4378+
return false;
4379+
}
43004380
}
43014381
}
43024382
return true;
43034383
} else {
4304-
return server_sent_event(sink, "data", res_json);
4384+
if (oaicompat == OAICOMPAT_TYPE_API_CHAT) {
4385+
std::string s = safe_json_to_str(res_json) + "\n";
4386+
return sink.write(s.data(), s.size());
4387+
} else {
4388+
return server_sent_event(sink, "data", res_json);
4389+
}
43054390
}
43064391
}, [&](const json & error_data) {
43074392
server_sent_event(sink, "error", error_data);
43084393
}, [&sink]() {
43094394
// note: do not use req.is_connection_closed here because req is already destroyed
43104395
return !sink.is_writable();
43114396
});
4312-
if (oaicompat != OAICOMPAT_TYPE_NONE) {
4397+
if (oaicompat != OAICOMPAT_TYPE_NONE && oaicompat != OAICOMPAT_TYPE_API_CHAT) {
43134398
static const std::string ev_done = "data: [DONE]\n\n";
43144399
sink.write(ev_done.data(), ev_done.size());
43154400
}
@@ -4436,6 +4521,16 @@ int main(int argc, char ** argv) {
44364521
}
44374522

44384523
auto body = json::parse(req.body);
4524+
// Ollama chat endpoint specific handling
4525+
auto OAICOMPAT_TYPE = OAICOMPAT_TYPE_CHAT;
4526+
if (req.path == "/api/chat") {
4527+
OAICOMPAT_TYPE = OAICOMPAT_TYPE_API_CHAT;
4528+
4529+
// Set default stream to true for /api/chat
4530+
if (!body.contains("stream")) {
4531+
body["stream"] = true;
4532+
}
4533+
}
44394534
std::vector<raw_buffer> files;
44404535
json data = oaicompat_chat_params_parse(
44414536
body,
@@ -4448,7 +4543,7 @@ int main(int argc, char ** argv) {
44484543
files,
44494544
req.is_connection_closed,
44504545
res,
4451-
OAICOMPAT_TYPE_CHAT);
4546+
OAICOMPAT_TYPE);
44524547
};
44534548

44544549
// same with handle_chat_completions, but without inference part

0 commit comments

Comments
 (0)