Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 8928a8c

Browse files
committed
fix: use different interface for remote engine
1 parent 90694c4 commit 8928a8c

File tree

7 files changed

+135
-87
lines changed

7 files changed

+135
-87
lines changed

engine/controllers/models.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Models : public drogon::HttpController<Models, false> {
3535
ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post);
3636
ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get);
3737
ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post);
38-
ADD_METHOD_TO(Models::GetRemoteModels, "/v1/remote/{1}", Get);
38+
ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get);
3939
METHOD_LIST_END
4040

4141
explicit Models(std::shared_ptr<ModelService> model_service,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
#pragma once

#include <functional>
#include <memory>

#include "json/value.h"

// Interface for remote (API-backed) inference engines, e.g. OpenAI or
// Anthropic. Unlike the local EngineI, remote engines do not manage local
// processes or log files, so this interface deliberately omits
// IsSupported/SetFileLogger/SetLogLevel.
//
// Every async method receives the request body as a shared Json::Value and a
// callback invoked as callback(status, result); both arguments are moved in.
class RemoteEngineI {
 public:
  virtual ~RemoteEngineI() = default;

  // Forward a chat-completion request to the remote provider.
  virtual void HandleChatCompletion(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Forward an embedding request to the remote provider.
  virtual void HandleEmbedding(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Register/activate a model configuration for subsequent requests.
  virtual void LoadModel(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Deactivate a previously loaded model.
  virtual void UnloadModel(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Report the status of a single loaded model.
  virtual void GetModelStatus(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Get list of running models.
  virtual void GetModels(
      std::shared_ptr<Json::Value> json_body,
      std::function<void(Json::Value&&, Json::Value&&)>&& callback) = 0;

  // Get available remote models (synchronous; returns provider catalog or
  // an object with an "error" member on failure).
  virtual Json::Value GetRemoteModels() = 0;
};

engine/extensions/remote-engine/remote_engine.cc

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -707,50 +707,9 @@ void RemoteEngine::HandleEmbedding(
707707
callback(Json::Value(), Json::Value());
708708
}
709709

710-
bool RemoteEngine::IsSupported(const std::string& f) {
711-
if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" ||
712-
f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" ||
713-
f == "SetLogLevel") {
714-
return true;
715-
}
716-
return false;
717-
}
718-
719-
bool RemoteEngine::SetFileLogger(int max_log_lines,
720-
const std::string& log_path) {
721-
if (!async_file_logger_) {
722-
async_file_logger_ = std::make_unique<trantor::FileLogger>();
723-
}
724-
725-
async_file_logger_->setFileName(log_path);
726-
async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines
727-
async_file_logger_->startLogging();
728-
trantor::Logger::setOutputFunction(
729-
[&](const char* msg, const uint64_t len) {
730-
if (async_file_logger_)
731-
async_file_logger_->output_(msg, len);
732-
},
733-
[&]() {
734-
if (async_file_logger_)
735-
async_file_logger_->flush();
736-
});
737-
freopen(log_path.c_str(), "w", stderr);
738-
freopen(log_path.c_str(), "w", stdout);
739-
return true;
740-
}
741-
742-
void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) {
743-
trantor::Logger::setLogLevel(log_level);
744-
}
745-
746710
Json::Value RemoteEngine::GetRemoteModels() {
747711
CTL_WRN("Not implemented yet!");
748712
return {};
749713
}
750714

751-
extern "C" {
752-
EngineI* get_engine() {
753-
return new RemoteEngine();
754-
}
755-
}
756715
} // namespace remote_engine

engine/extensions/remote-engine/remote_engine.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <shared_mutex>
88
#include <string>
99
#include <unordered_map>
10-
#include "cortex-common/EngineI.h"
10+
#include "cortex-common/remote_enginei.h"
1111
#include "extensions/remote-engine/template_renderer.h"
1212
#include "utils/engine_constants.h"
1313
#include "utils/file_logger.h"
@@ -31,7 +31,7 @@ struct CurlResponse {
3131
std::string error_message;
3232
};
3333

34-
class RemoteEngine : public EngineI {
34+
class RemoteEngine : public RemoteEngineI {
3535
protected:
3636
// Model configuration
3737
struct ModelConfig {
@@ -95,9 +95,7 @@ class RemoteEngine : public EngineI {
9595
void HandleEmbedding(
9696
std::shared_ptr<Json::Value> json_body,
9797
std::function<void(Json::Value&&, Json::Value&&)>&& callback) override;
98-
bool IsSupported(const std::string& feature) override;
99-
bool SetFileLogger(int max_log_lines, const std::string& log_path) override;
100-
void SetLogLevel(trantor::Logger::LogLevel logLevel) override;
98+
10199
Json::Value GetRemoteModels() override;
102100
};
103101

engine/services/engine_service.cc

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -694,21 +694,6 @@ cpp::result<void, std::string> EngineService::LoadEngine(
694694
engines_[engine_name].engine = new remote_engine::AnthropicEngine();
695695
}
696696

697-
auto& en = std::get<EngineI*>(engines_[ne].engine);
698-
auto config = file_manager_utils::GetCortexConfig();
699-
if (en->IsSupported("SetFileLogger")) {
700-
en->SetFileLogger(config.maxLogLines,
701-
(std::filesystem::path(config.logFolderPath) /
702-
std::filesystem::path(config.logLlamaCppPath))
703-
.string());
704-
} else {
705-
CTL_WRN("Method SetFileLogger is not supported yet");
706-
}
707-
if (en->IsSupported("SetLogLevel")) {
708-
en->SetLogLevel(trantor::Logger::logLevel());
709-
} else {
710-
CTL_WRN("Method SetLogLevel is not supported yet");
711-
}
712697
CTL_INF("Loaded engine: " << engine_name);
713698
return {};
714699
}
@@ -883,8 +868,11 @@ cpp::result<void, std::string> EngineService::UnloadEngine(
883868
if (!IsEngineLoaded(ne)) {
884869
return cpp::fail("Engine " + ne + " is not loaded yet!");
885870
}
886-
EngineI* e = std::get<EngineI*>(engines_[ne].engine);
887-
delete e;
871+
if (std::holds_alternative<EngineI*>(engines_[ne].engine)) {
872+
delete std::get<EngineI*>(engines_[ne].engine);
873+
} else {
874+
delete std::get<RemoteEngineI*>(engines_[ne].engine);
875+
}
888876

889877
#if defined(_WIN32)
890878
if (!RemoveDllDirectory(engines_[ne].cookie)) {
@@ -1100,7 +1088,22 @@ cpp::result<Json::Value, std::string> EngineService::GetRemoteModels(
11001088
return cpp::fail(r.error());
11011089
}
11021090

1103-
auto& e = std::get<EngineI*>(engines_[engine_name].engine);
1091+
if (!IsEngineLoaded(engine_name)) {
1092+
auto exist_engine = GetEngineByNameAndVariant(engine_name);
1093+
if (exist_engine.has_error()) {
1094+
return cpp::fail("Remote engine '" + engine_name + "' is not installed");
1095+
}
1096+
1097+
if (engine_name == kOpenAiEngine) {
1098+
engines_[engine_name].engine = new remote_engine::OpenAiEngine();
1099+
} else {
1100+
engines_[engine_name].engine = new remote_engine::AnthropicEngine();
1101+
}
1102+
1103+
CTL_INF("Loaded engine: " << engine_name);
1104+
}
1105+
1106+
auto& e = std::get<RemoteEngineI*>(engines_[engine_name].engine);
11041107
auto res = e->GetRemoteModels();
11051108
if (!res["error"].isNull()) {
11061109
return cpp::fail(res["error"].asString());

engine/services/engine_service.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "common/engine_servicei.h"
1212
#include "cortex-common/EngineI.h"
1313
#include "cortex-common/cortexpythoni.h"
14+
#include "cortex-common/remote_enginei.h"
1415
#include "database/engines.h"
1516
#include "extensions/remote-engine/remote_engine.h"
1617
#include "services/download_service.h"
@@ -37,7 +38,7 @@ struct EngineUpdateResult {
3738
}
3839
};
3940

40-
using EngineV = std::variant<EngineI*, CortexPythonEngineI*>;
41+
using EngineV = std::variant<EngineI*, CortexPythonEngineI*, RemoteEngineI*>;
4142

4243
class EngineService : public EngineServiceI {
4344
private:

engine/services/inference_service.cc

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,26 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
2424
return cpp::fail(std::make_pair(stt, res));
2525
}
2626

27-
auto engine = std::get<EngineI*>(engine_result.value());
28-
engine->HandleChatCompletion(
29-
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
30-
if (!tool_choice.isNull()) {
31-
res["tool_choice"] = tool_choice;
32-
}
33-
q->push(std::make_pair(status, res));
34-
});
27+
if (std::holds_alternative<EngineI*>(engine_result.value())) {
28+
std::get<EngineI*>(engine_result.value())
29+
->HandleChatCompletion(
30+
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
31+
if (!tool_choice.isNull()) {
32+
res["tool_choice"] = tool_choice;
33+
}
34+
q->push(std::make_pair(status, res));
35+
});
36+
} else {
37+
std::get<RemoteEngineI*>(engine_result.value())
38+
->HandleChatCompletion(
39+
json_body, [q, tool_choice](Json::Value status, Json::Value res) {
40+
if (!tool_choice.isNull()) {
41+
res["tool_choice"] = tool_choice;
42+
}
43+
q->push(std::make_pair(status, res));
44+
});
45+
}
46+
3547
return {};
3648
}
3749

@@ -53,10 +65,18 @@ cpp::result<void, InferResult> InferenceService::HandleEmbedding(
5365
LOG_WARN << "Engine is not loaded yet";
5466
return cpp::fail(std::make_pair(stt, res));
5567
}
56-
auto engine = std::get<EngineI*>(engine_result.value());
57-
engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
58-
q->push(std::make_pair(status, res));
59-
});
68+
69+
if (std::holds_alternative<EngineI*>(engine_result.value())) {
70+
std::get<EngineI*>(engine_result.value())
71+
->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
72+
q->push(std::make_pair(status, res));
73+
});
74+
} else {
75+
std::get<RemoteEngineI*>(engine_result.value())
76+
->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) {
77+
q->push(std::make_pair(status, res));
78+
});
79+
}
6080
return {};
6181
}
6282

@@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel(
83103

84104
// might need mutex here
85105
auto engine_result = engine_service_->GetLoadedEngine(engine_type);
86-
auto engine = std::get<EngineI*>(engine_result.value());
87-
engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
88-
stt = status;
89-
r = res;
90-
});
106+
107+
if (std::holds_alternative<EngineI*>(engine_result.value())) {
108+
std::get<EngineI*>(engine_result.value())
109+
->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
110+
stt = status;
111+
r = res;
112+
});
113+
} else {
114+
std::get<RemoteEngineI*>(engine_result.value())
115+
->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) {
116+
stt = status;
117+
r = res;
118+
});
119+
}
91120
return std::make_pair(stt, r);
92121
}
93122

@@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name,
110139
json_body["model"] = model_id;
111140

112141
LOG_TRACE << "Start unload model";
113-
auto engine = std::get<EngineI*>(engine_result.value());
114-
engine->UnloadModel(std::make_shared<Json::Value>(json_body),
142+
if (std::holds_alternative<EngineI*>(engine_result.value())) {
143+
std::get<EngineI*>(engine_result.value())
144+
->UnloadModel(std::make_shared<Json::Value>(json_body),
145+
[&r, &stt](Json::Value status, Json::Value res) {
146+
stt = status;
147+
r = res;
148+
});
149+
} else {
150+
std::get<RemoteEngineI*>(engine_result.value())
151+
->UnloadModel(std::make_shared<Json::Value>(json_body),
115152
[&r, &stt](Json::Value status, Json::Value res) {
116153
stt = status;
117154
r = res;
118155
});
156+
}
157+
119158
return std::make_pair(stt, r);
120159
}
121160

@@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus(
141180
}
142181

143182
LOG_TRACE << "Start to get model status";
144-
auto engine = std::get<EngineI*>(engine_result.value());
145-
engine->GetModelStatus(json_body,
183+
184+
if (std::holds_alternative<EngineI*>(engine_result.value())) {
185+
std::get<EngineI*>(engine_result.value())
186+
->GetModelStatus(json_body,
187+
[&stt, &r](Json::Value status, Json::Value res) {
188+
stt = status;
189+
r = res;
190+
});
191+
} else {
192+
std::get<RemoteEngineI*>(engine_result.value())
193+
->GetModelStatus(json_body,
146194
[&stt, &r](Json::Value status, Json::Value res) {
147195
stt = status;
148196
r = res;
149197
});
198+
}
199+
150200
return std::make_pair(stt, r);
151201
}
152202

0 commit comments

Comments
 (0)