Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 65f9790

Browse files
committed
feat: rendering chat_template
1 parent 5414e02 commit 65f9790

File tree

15 files changed

+4321
-136
lines changed

15 files changed

+4321
-136
lines changed

engine/cli/commands/chat_completion_cmd.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,8 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
151151
json_data["model"] = model_handle;
152152
json_data["stream"] = true;
153153

154-
std::string json_payload = json_data.toStyledString();
155-
156-
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
154+
// CURLOPT_POSTFIELDS stores only the pointer it is given, and the string
// returned by toStyledString() is a temporary that is destroyed at the end of
// this statement. CURLOPT_COPYPOSTFIELDS makes libcurl copy the payload during
// the call, so nothing dangles once the temporary is gone.
curl_easy_setopt(curl, CURLOPT_COPYPOSTFIELDS,
                 json_data.toStyledString().c_str());
157156

158157
std::string ai_chat;
159158
StreamingCallback callback;

engine/common/model_metadata.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include "common/tokenizer.h"
4+
#include <sstream>
5+
6+
struct ModelMetadata {
7+
uint32_t version;
8+
uint64_t tensor_count;
9+
uint64_t metadata_kv_count;
10+
std::unique_ptr<Tokenizer> tokenizer;
11+
12+
std::string ToString() const {
13+
std::ostringstream ss;
14+
ss << "ModelMetadata {\n"
15+
<< "version: " << version << "\n"
16+
<< "tensor_count: " << tensor_count << "\n"
17+
<< "metadata_kv_count: " << metadata_kv_count << "\n"
18+
<< "tokenizer: ";
19+
20+
if (tokenizer) {
21+
ss << "\n" << tokenizer->ToString();
22+
} else {
23+
ss << "null";
24+
}
25+
26+
ss << "\n}";
27+
return ss.str();
28+
}
29+
};

engine/common/tokenizer.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#pragma once
2+
3+
#include <sstream>
4+
#include <string>
5+
6+
// Tokenizer settings common to every model format: the special-token strings,
// the add_bos/add_eos flags and the chat template string.
struct Tokenizer {
  std::string eos_token = "";
  bool add_eos_token = true;

  std::string bos_token = "";
  bool add_bos_token = true;

  std::string unknown_token = "";
  std::string padding_token = "";

  std::string chat_template = "";

  // Renders the common fields as `name: value` lines (string values are
  // wrapped in double quotes). The result has no trailing newline.
  std::string BaseToString() const {
    std::ostringstream ss;
    ss << "eos_token: \"" << eos_token << "\"\n"
       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
       << "bos_token: \"" << bos_token << "\"\n"
       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
       << "unknown_token: \"" << unknown_token << "\"\n"
       << "padding_token: \"" << padding_token << "\"\n"
       << "chat_template: \"" << chat_template << "\"";
    return ss.str();
  }

  virtual ~Tokenizer() = default;

  // Human-readable dump of the concrete tokenizer. Declared const because it
  // only reads state, which also lets it be called through a const reference.
  virtual std::string ToString() const = 0;
};

// Tokenizer variant that carries an additional `pre` string.
struct GgufTokenizer : public Tokenizer {
  std::string pre = "";

  ~GgufTokenizer() override = default;

  // Prints the common fields followed by `pre`, wrapped in
  // "GgufTokenizer { ... }".
  std::string ToString() const override {
    std::ostringstream ss;
    ss << "GgufTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "pre: \"" << pre << "\"\n";
    ss << "}";
    return ss.str();
  }
};

// Tokenizer variant that carries an additional `add_prefix_space` flag.
struct SafeTensorTokenizer : public Tokenizer {
  bool add_prefix_space = true;

  ~SafeTensorTokenizer() override = default;

  // Prints the common fields followed by `add_prefix_space`, wrapped in
  // "SafeTensorTokenizer { ... }".
  std::string ToString() const override {
    std::ostringstream ss;
    ss << "SafeTensorTokenizer {\n";
    // Add base class members
    ss << BaseToString() << "\n";
    // Add derived class members
    ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
    ss << "}";
    return ss.str();
  }
};

engine/controllers/engines.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "utils/archive_utils.h"
44
#include "utils/cortex_utils.h"
55
#include "utils/engine_constants.h"
6+
#include "utils/jinja_utils.h"
67
#include "utils/logging_utils.h"
78
#include "utils/string_utils.h"
89

@@ -20,6 +21,41 @@ std::string NormalizeEngine(const std::string& engine) {
2021
};
2122
} // namespace
2223

24+
void Engines::TestJinja(
25+
const HttpRequestPtr& req,
26+
std::function<void(const HttpResponsePtr&)>&& callback) {
27+
auto body = req->getJsonObject();
28+
if (body == nullptr) {
29+
Json::Value ret;
30+
ret["message"] = "Body can't be empty";
31+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
32+
resp->setStatusCode(k400BadRequest);
33+
callback(resp);
34+
return;
35+
}
36+
37+
auto jinja = body->get("jinja", "").asString();
38+
auto data = body->get("data", {});
39+
auto bos_token = data.get("bos_token", "").asString();
40+
auto eos_token = data.get("eos_token", "").asString();
41+
42+
auto rendered_data = jinja::RenderTemplate(jinja, data, bos_token, eos_token);
43+
44+
if (rendered_data.has_error()) {
45+
Json::Value ret;
46+
ret["message"] = rendered_data.error();
47+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
48+
resp->setStatusCode(k400BadRequest);
49+
callback(resp);
50+
return;
51+
}
52+
// TODO: namh recheck all the api using this. because we have an issue with Germany locale before.
53+
auto resp = HttpResponse::newHttpResponse();
54+
resp->setBody(rendered_data.value());
55+
resp->setContentTypeCode(drogon::CT_TEXT_PLAIN);
56+
callback(resp);
57+
}
58+
2359
void Engines::ListEngine(
2460
const HttpRequestPtr& req,
2561
std::function<void(const HttpResponsePtr&)>&& callback) const {

engine/controllers/engines.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ class Engines : public drogon::HttpController<Engines, false> {
1212
public:
1313
METHOD_LIST_BEGIN
1414

15+
ADD_METHOD_TO(Engines::TestJinja, "/v1/jinja", Options, Post);
16+
1517
// install engine
1618
METHOD_ADD(Engines::InstallEngine, "/{1}/install", Options, Post);
1719
ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}/install", Options,
@@ -110,6 +112,9 @@ class Engines : public drogon::HttpController<Engines, false> {
110112
std::function<void(const HttpResponsePtr&)>&& callback,
111113
const std::string& engine) const;
112114

115+
void TestJinja(const HttpRequestPtr& req,
116+
std::function<void(const HttpResponsePtr&)>&& callback);
117+
113118
void LoadEngine(const HttpRequestPtr& req,
114119
std::function<void(const HttpResponsePtr&)>&& callback,
115120
const std::string& engine);

engine/controllers/server.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "trantor/utils/Logger.h"
44
#include "utils/cortex_utils.h"
55
#include "utils/function_calling/common.h"
6-
#include "utils/http_util.h"
76

87
using namespace inferences;
98

@@ -27,6 +26,14 @@ void server::ChatCompletion(
2726
std::function<void(const HttpResponsePtr&)>&& callback) {
2827
LOG_DEBUG << "Start chat completion";
2928
auto json_body = req->getJsonObject();
29+
if (json_body == nullptr) {
30+
Json::Value ret;
31+
ret["message"] = "Body can't be empty";
32+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
33+
resp->setStatusCode(k400BadRequest);
34+
callback(resp);
35+
return;
36+
}
3037
bool is_stream = (*json_body).get("stream", false).asBool();
3138
auto model_id = (*json_body).get("model", "invalid_model").asString();
3239
auto engine_type = [this, &json_body]() -> std::string {

engine/services/engine_service.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <mutex>
55
#include <optional>
66
#include <string>
7-
#include <string_view>
87
#include <unordered_map>
98
#include <vector>
109

@@ -17,7 +16,6 @@
1716
#include "utils/cpuid/cpu_info.h"
1817
#include "utils/dylib.h"
1918
#include "utils/dylib_path_manager.h"
20-
#include "utils/engine_constants.h"
2119
#include "utils/github_release_utils.h"
2220
#include "utils/result.hpp"
2321
#include "utils/system_info_utils.h"
@@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
4846
struct EngineInfo {
4947
std::unique_ptr<cortex_cpp::dylib> dl;
5048
EngineV engine;
51-
#if defined(_WIN32)
52-
DLL_DIRECTORY_COOKIE cookie;
53-
DLL_DIRECTORY_COOKIE cuda_cookie;
54-
#endif
5549
};
5650

5751
std::mutex engines_mutex_;

engine/services/inference_service.cc

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#include "inference_service.h"
22
#include <drogon/HttpTypes.h>
33
#include "utils/engine_constants.h"
4+
#include "utils/file_manager_utils.h"
45
#include "utils/function_calling/common.h"
6+
#include "utils/gguf_metadata_reader.h"
7+
#include "utils/jinja_utils.h"
58

69
namespace services {
710
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
@@ -24,6 +27,41 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
2427
return cpp::fail(std::make_pair(stt, res));
2528
}
2629

30+
{
  // When the request names at least one model file, read that file's GGUF
  // metadata and, if it carries a chat template, render the request's
  // "messages" through it and store the result under "prompt".
  // Every step is best-effort: on any failure the request is left untouched.
  const bool has_file_list = json_body->isMember("files") &&
                             (*json_body)["files"].isArray() &&
                             !(*json_body)["files"].empty();
  // Only the first entry is used; it must be a string to be usable as a path.
  if (has_file_list && (*json_body)["files"][0].isString()) {
    auto file = (*json_body)["files"][0].asString();
    auto model_metadata_res = cortex_utils::ReadGgufMetadata(
        file_manager_utils::ToAbsoluteCortexDataPath(
            std::filesystem::path(file)));
    if (model_metadata_res.has_value()) {
      auto metadata = model_metadata_res.value().get();
      // The tokenizer pointer may legitimately be null, so check it before
      // dereferencing.
      if (metadata != nullptr && metadata->tokenizer != nullptr &&
          !metadata->tokenizer->chat_template.empty()) {
        // get() is used instead of operator[] so that a missing "messages"
        // key is not inserted into the request body as a null member.
        const auto messages =
            json_body->get("messages", Json::Value(Json::arrayValue));
        Json::Value messages_jsoncpp(Json::arrayValue);
        for (const auto& message : messages) {
          messages_jsoncpp.append(message);
        }

        Json::Value template_data_json;
        template_data_json["messages"] = messages_jsoncpp;

        auto prompt_result = jinja::RenderTemplate(
            metadata->tokenizer->chat_template, template_data_json,
            metadata->tokenizer->bos_token, metadata->tokenizer->eos_token);
        if (prompt_result.has_value()) {
          (*json_body)["prompt"] = prompt_result.value();
        } else {
          CTL_ERR("Failed to render prompt: " + prompt_result.error());
        }
      }
    }
  }
}
62+
63+
CTL_INF("Prompt is: " + json_body->get("prompt", "").asString());
64+
2765
auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
2866
if (!tool_choice.isNull()) {
2967
res["tool_choice"] = tool_choice;
@@ -297,4 +335,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
297335
}
298336
return true;
299337
}
300-
} // namespace services
338+
} // namespace services

engine/services/inference_service.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
#include <queue>
66
#include "services/engine_service.h"
77
#include "utils/result.hpp"
8-
#include "extensions/remote-engine/remote_engine.h"
8+
99
namespace services {
10+
1011
// Status and result
1112
using InferResult = std::pair<Json::Value, Json::Value>;
1213

engine/services/model_service.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "database/models.h"
1111
#include "hardware_service.h"
1212
#include "utils/cli_selection_utils.h"
13-
#include "utils/cortex_utils.h"
1413
#include "utils/engine_constants.h"
1514
#include "utils/file_manager_utils.h"
1615
#include "utils/huggingface_utils.h"

0 commit comments

Comments
 (0)