From 190d40b9a5c98c07d8520fbbd91fa7e8825f1bb6 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Mon, 11 Nov 2024 23:41:37 +0700 Subject: [PATCH 01/33] Init remote engine --- .../remote-engine/TemplateRenderer.cc | 91 ++++ .../remote-engine/TemplateRenderer.h | 28 ++ .../extensions/remote-engine/remote_engine.cc | 425 ++++++++++++++++++ .../extensions/remote-engine/remote_engine.h | 84 ++++ engine/services/inference_service.h | 2 +- engine/vcpkg.json | 3 +- 6 files changed, 631 insertions(+), 2 deletions(-) create mode 100644 engine/extensions/remote-engine/TemplateRenderer.cc create mode 100644 engine/extensions/remote-engine/TemplateRenderer.h create mode 100644 engine/extensions/remote-engine/remote_engine.cc create mode 100644 engine/extensions/remote-engine/remote_engine.h diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc new file mode 100644 index 000000000..284b64bd3 --- /dev/null +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -0,0 +1,91 @@ +#include "TemplateRenderer.h" +#include +#include +#include +#include + +TemplateRenderer::TemplateRenderer() { + // Configure Inja environment + env_.set_trim_blocks(true); + env_.set_lstrip_blocks(true); + + // Add tojson function for all value types + env_.add_callback("tojson", 1, [](inja::Arguments& args) { + if (args.empty()) { + return inja::json(nullptr); + } + const auto& value = *args[0]; + + if (value.is_string()) { + return inja::json(std::string("\"") + value.get() + "\""); + } + return value; + }); +} + +std::string TemplateRenderer::jsonToString(const Json::Value& value) { + Json::StreamWriterBuilder writer; + return Json::writeString(writer, value); +} + +bool TemplateRenderer::validateJson(const std::string& jsonStr) { + Json::Value root; + Json::CharReaderBuilder builder; + std::string errors; + std::istringstream iss(jsonStr); + return Json::parseFromStream(builder, iss, &root, &errors); +} + +std::string TemplateRenderer::render(const std::string& tmpl, const Json::Value& data) { + try { + // Create the input data structure expected by the template + Json::Value template_data; + template_data["input_request"] = data; + + // Convert to string for logging + std::string dataStr = jsonToString(template_data); + + // Debug output + LOG_DEBUG << "Template: " << tmpl; + LOG_DEBUG << "Data: " << dataStr; + + // Convert to inja's json format + auto inja_data = inja::json::parse(dataStr); + + // Render template + std::string result = env_.render(tmpl, inja_data); + + // Clean up any potential double quotes in JSON strings + result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); + + LOG_DEBUG << "Result: " << result; + + // Validate JSON + if (!validateJson(result)) { + throw std::runtime_error("Invalid JSON in rendered template"); + } + + return result; + } + catch (const std::exception& e) { + LOG_ERROR << "Template rendering failed: " << e.what(); + LOG_ERROR << "Template: " << tmpl; + throw std::runtime_error(std::string("Template rendering failed: ") + e.what()); + } +} + +std::string TemplateRenderer::renderFile(const std::string& template_path, const Json::Value& data) { + try { + // Convert JsonCpp Value to string + std::string dataStr = jsonToString(data); + + // Parse as inja json + auto inja_data = inja::json::parse(dataStr); + + // Load and render template + return env_.render_file(template_path, inja_data); + } + catch (const std::exception& e) { + throw std::runtime_error(std::string("Template file rendering failed: ") + e.what()); + } +} \ No newline at end of file diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h new file mode 100644 index 000000000..0a382eee6 --- /dev/null +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include +#include "json/json.h" +#include "trantor/utils/Logger.h" + +class TemplateRenderer { +public: + TemplateRenderer(); + ~TemplateRenderer() = default; + + // Render template with data + std::string render(const std::string& tmpl, const Json::Value& data); + + // Load template from file and render + std::string renderFile(const std::string& template_path, const Json::Value& data); + +private: + // Helper function to convert JsonCpp Value to string representation + static std::string jsonToString(const Json::Value& value); + + // Helper function to validate JSON string + static bool validateJson(const std::string& jsonStr); + + inja::Environment env_; +}; \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc new file mode 100644 index 000000000..bab63bb75 --- /dev/null +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -0,0 +1,425 @@ +#include "remote_engine.h" +#include +#include +#include +#include + +constexpr const int k200OK = 200; +constexpr const int k400BadRequest = 400; +constexpr const int k409Conflict = 409; +constexpr const int k500InternalServerError = 500; +constexpr const int kFileLoggerOption = 0; + +std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, + const std::string& apiKey) { + const std::string placeholder = "{{api_key}}"; + std::string result = templateStr; + size_t pos = result.find(placeholder); + + if (pos != std::string::npos) { + result.replace(pos, placeholder.length(), apiKey); + } + + return result; +} + +static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, + std::string* data) { + data->append(ptr, size * nmemb); + return size * nmemb; +} + +RemoteEngine::RemoteEngine() { + curl_global_init(CURL_GLOBAL_ALL); +} + +RemoteEngine::~RemoteEngine() { + curl_global_cleanup(); +} + +RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( + const std::string& model_id) { + std::shared_lock lock(models_mutex_); + auto it = models_.find(model_id); + if (it != models_.end()) { + return &it->second; + } + return nullptr; +} + +CurlResponse RemoteEngine::MakeGetModelsRequest() { + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = metadata_["get_models_url"].asString(); + + struct curl_slist* headers = nullptr; + + headers = curl_slist_append(headers, api_key_template_.c_str()); + std::cout << "api_key: " << api_key_template_ << std::endl; + + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + std::string response_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + } else { + response.body = response_string; + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +CurlResponse RemoteEngine::MakeChatCompletionRequest( + const ModelConfig& config, const std::string& body, + const std::string& method) { + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + std::string full_url = + config.transform_req["chat_completions"]["url"].as(); + + struct curl_slist* headers = nullptr; + if (!config.api_key.empty()) { + + headers = curl_slist_append(headers, api_key_template_.c_str()); + std::cout << "api_key: " << api_key_template_ << std::endl; + } + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + if (method == "POST") { + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + } + + std::string response_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + } else { + response.body = response_string; + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +bool RemoteEngine::LoadModelConfig(const std::string& model_id, + const std::string& yaml_path, + const std::string& api_key) { + try { + YAML::Node config = YAML::LoadFile(yaml_path); + + ModelConfig model_config; + model_config.model_id = model_id; + + // Required fields + if (!config["api_key_template"]) { + LOG_ERROR << "Missing required fields in config for model " << model_id; + return false; + } + + model_config.api_key = api_key; + // model_config.url = ; + // Optional fields + if (config["api_key_template"]) { + api_key_template_ = ReplaceApiKeyPlaceholder( + config["api_key_template"].as(), api_key); + } + if (config["TransformReq"]) { + model_config.transform_req = config["TransformReq"]; + } else { + LOG_WARN << "Missing TransformReq in config for model " << model_id; + } + if (config["TransformResp"]) { + model_config.transform_resp = config["TransformResp"]; + } else { + LOG_WARN << "Missing TransformResp in config for model " << model_id; + } + + model_config.is_loaded = true; + + // Thread-safe update of models map + { + std::unique_lock lock(models_mutex_); + models_[model_id] = std::move(model_config); + } + + return true; + } catch (const YAML::Exception& e) { + LOG_ERROR << "Failed to load config for model " << model_id << ": " + << e.what(); + return false; + } +} + +void RemoteEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json)); +} + +void RemoteEngine::LoadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model_id") || !json_body->isMember("model_path") || + !json_body->isMember("api_key")) { + Json::Value error; + error["error"] = "Missing required fields: model_id or model_path"; + callback(Json::Value(), std::move(error)); + return; + } + + const std::string& model_id = (*json_body)["model_id"].asString(); + const std::string& model_path = (*json_body)["model_path"].asString(); + const std::string& api_key = (*json_body)["api_key"].asString(); + + if (!LoadModelConfig(model_id, model_path, api_key)) { + Json::Value error; + error["error"] = "Failed to load model configuration"; + callback(Json::Value(), std::move(error)); + return; + } + if (json_body->isMember("metadata")) { + metadata_ = (*json_body)["metadata"]; + } + + Json::Value response; + response["status"] = "Model loaded successfully"; + callback(Json::Value(), std::move(response)); +} + +void RemoteEngine::UnloadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model_id")) { + Json::Value error; + error["error"] = "Missing required field: model_id"; + callback(Json::Value(), std::move(error)); + return; + } + + const std::string& model_id = (*json_body)["model_id"].asString(); + + { + std::unique_lock lock(models_mutex_); + models_.erase(model_id); + } + + Json::Value response; + response["status"] = "Model unloaded successfully"; + callback(std::move(response), Json::Value()); +} + +void RemoteEngine::HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Missing required fields: model"; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model_id = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model_id); + + if (!model_config || !model_config->is_loaded) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Model not found or not loaded: " + model_id; + callback(std::move(status), std::move(error)); + return; + } + + Json::FastWriter writer; + std::string request_body = writer.write((*json_body)); + std::cout << "template: " + << model_config->transform_req["chat_completions"]["template"] + .as() + << std::endl; + std::string result = renderer_.render( + model_config->transform_req["chat_completions"]["template"] + .as(), + (*json_body)); + + auto response = MakeChatCompletionRequest(*model_config, result); + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json)); +} + +void RemoteEngine::GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model_id")) { + Json::Value error; + error["error"] = "Missing required field: model_id"; + callback(Json::Value(), std::move(error)); + return; + } + + const std::string& model_id = (*json_body)["model_id"].asString(); + auto* model_config = GetModelConfig(model_id); + + if (!model_config) { + Json::Value error; + error["error"] = "Model not found: " + model_id; + callback(Json::Value(), std::move(error)); + return; + } + + Json::Value response; + response["model_id"] = model_id; + response["is_loaded"] = model_config->is_loaded; + response["url"] = model_config->url; + + callback(std::move(response), Json::Value()); +} + +// Implement remaining virtual functions +void RemoteEngine::HandleEmbedding( + std::shared_ptr, + std::function&& callback) { + callback(Json::Value(), Json::Value()); +} + +bool RemoteEngine::IsSupported(const std::string& f) { + if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" || + f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" || + f == "SetLogLevel") { + return true; + } + return false; +} + +bool RemoteEngine::SetFileLogger(int max_log_lines, + const std::string& log_path) { + if (!async_file_logger_) { + async_file_logger_ = std::make_unique(); + } + + async_file_logger_->setFileName(log_path); + async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines + async_file_logger_->startLogging(); + trantor::Logger::setOutputFunction( + [&](const char* msg, const uint64_t len) { + if (async_file_logger_) + async_file_logger_->output_(msg, len); + }, + [&]() { + if (async_file_logger_) + async_file_logger_->flush(); + }); + freopen(log_path.c_str(), "w", stderr); + freopen(log_path.c_str(), "w", stdout); +} + +void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { + trantor::Logger::setLogLevel(log_level); +} + +extern "C" { +EngineI* get_engine() { + return new RemoteEngine(); +} +} \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h new file mode 100644 index 000000000..e55107286 --- /dev/null +++ b/engine/extensions/remote-engine/remote_engine.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "cortex-common/EngineI.h" +#include "extensions/remote-engine/TemplateRenderer.h" +#include "utils/file_logger.h" +// Helper for CURL response +struct CurlResponse { + std::string body; + bool error{false}; + std::string error_message; +}; + +class RemoteEngine : public EngineI { + private: + // Model configuration + struct ModelConfig { + std::string model_id; + std::string api_key; + std::string url; + YAML::Node transform_req; + YAML::Node transform_resp; + bool is_loaded{false}; + }; + + // Thread-safe model config storage + mutable std::shared_mutex models_mutex_; + std::unordered_map models_; + TemplateRenderer renderer_; + Json::Value metadata_; + std::string api_key_template_; + std::unique_ptr async_file_logger_; + + // Helper functions + CurlResponse MakeChatCompletionRequest(const ModelConfig& config, + const std::string& body, + const std::string& method = "POST"); + CurlResponse MakeGetModelsRequest(); + + // Internal model management + bool LoadModelConfig(const std::string& model_id, + const std::string& yaml_path, + const std::string& api_key); + ModelConfig* GetModelConfig(const std::string& model_id); + + public: + RemoteEngine(); + ~RemoteEngine(); + + // Main interface implementations + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) override; + + void LoadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) override; + + // Other required virtual functions + void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) override; + bool IsSupported(const std::string& feature) override; + bool SetFileLogger(int max_log_lines, const std::string& log_path) override; + void SetLogLevel(trantor::Logger::LogLevel logLevel) override; +}; \ No newline at end of file diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 7c09156ff..94097132a 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -5,7 +5,7 @@ #include #include "services/engine_service.h" #include "utils/result.hpp" - +#include "extensions/remote-engine/remote_engine.h" namespace services { // Status and result using InferResult = std::pair; diff --git a/engine/vcpkg.json b/engine/vcpkg.json index 64e6f6d26..c564e01c1 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -16,6 +16,7 @@ "eventpp", "sqlitecpp", "trantor", - "indicators" + "indicators", + "inja" ] } From c6124ba7c80183e31b628f76b901d16e78d29cd6 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Mon, 11 Nov 2024 23:56:14 +0700 Subject: [PATCH 02/33] Fix: CI build windows --- .../remote-engine/TemplateRenderer.cc | 126 +++++++++++++----- .../remote-engine/TemplateRenderer.h | 14 +- 2 files changed, 97 insertions(+), 43 deletions(-) diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 284b64bd3..89515ca38 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -1,8 +1,7 @@ #include "TemplateRenderer.h" #include #include -#include -#include + TemplateRenderer::TemplateRenderer() { // Configure Inja environment @@ -12,48 +11,32 @@ TemplateRenderer::TemplateRenderer() { // Add tojson function for all value types env_.add_callback("tojson", 1, [](inja::Arguments& args) { if (args.empty()) { - return inja::json(nullptr); + return nlohmann::json(nullptr); } const auto& value = *args[0]; if (value.is_string()) { - return inja::json(std::string("\"") + value.get() + "\""); + return nlohmann::json(std::string("\"") + value.get() + "\""); } return value; }); } -std::string TemplateRenderer::jsonToString(const Json::Value& value) { - Json::StreamWriterBuilder writer; - return Json::writeString(writer, value); -} - -bool TemplateRenderer::validateJson(const std::string& jsonStr) { - Json::Value root; - Json::CharReaderBuilder builder; - std::string errors; - std::istringstream iss(jsonStr); - return Json::parseFromStream(builder, iss, &root, &errors); -} - std::string TemplateRenderer::render(const std::string& tmpl, const Json::Value& data) { try { - // Create the input data structure expected by the template - Json::Value template_data; - template_data["input_request"] = data; + // Convert Json::Value to nlohmann::json + auto json_data = convertJsonValue(data); - // Convert to string for logging - std::string dataStr = jsonToString(template_data); + // Create the input data structure expected by the template + nlohmann::json template_data; + template_data["input_request"] = json_data; // Debug output LOG_DEBUG << "Template: " << tmpl; - LOG_DEBUG << "Data: " << dataStr; - - // Convert to inja's json format - auto inja_data = inja::json::parse(dataStr); + LOG_DEBUG << "Data: " << template_data.dump(2); // Render template - std::string result = env_.render(tmpl, inja_data); + std::string result = env_.render(tmpl, template_data); // Clean up any potential double quotes in JSON strings result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); @@ -61,9 +44,7 @@ std::string TemplateRenderer::render(const std::string& tmpl, const Json::Value& LOG_DEBUG << "Result: " << result; // Validate JSON - if (!validateJson(result)) { - throw std::runtime_error("Invalid JSON in rendered template"); - } + auto parsed = nlohmann::json::parse(result); return result; } @@ -74,16 +55,89 @@ std::string TemplateRenderer::render(const std::string& tmpl, const Json::Value& } } +nlohmann::json TemplateRenderer::convertJsonValue(const Json::Value& input) { + if (input.isNull()) { + return nullptr; + } + else if (input.isBool()) { + return input.asBool(); + } + else if (input.isInt()) { + return input.asInt(); + } + else if (input.isUInt()) { + return input.asUInt(); + } + else if (input.isDouble()) { + return input.asDouble(); + } + else if (input.isString()) { + return input.asString(); + } + else if (input.isArray()) { + nlohmann::json arr = nlohmann::json::array(); + for (const auto& element : input) { + arr.push_back(convertJsonValue(element)); + } + return arr; + } + else if (input.isObject()) { + nlohmann::json obj = nlohmann::json::object(); + for (const auto& key : input.getMemberNames()) { + obj[key] = convertJsonValue(input[key]); + } + return obj; + } + return nullptr; +} + +Json::Value TemplateRenderer::convertNlohmannJson(const nlohmann::json& input) { + if (input.is_null()) { + return Json::Value(); + } + else if (input.is_boolean()) { + return Json::Value(input.get()); + } + else if (input.is_number_integer()) { + return Json::Value(input.get()); + } + else if (input.is_number_unsigned()) { + return Json::Value(input.get()); + } + else if (input.is_number_float()) { + return Json::Value(input.get()); + } + else if (input.is_string()) { + return Json::Value(input.get()); + } + else if (input.is_array()) { + Json::Value arr(Json::arrayValue); + for (const auto& element : input) { + arr.append(convertNlohmannJson(element)); + } + return arr; + } + else if (input.is_object()) { + Json::Value obj(Json::objectValue); + for (auto it = input.begin(); it != input.end(); ++it) { + obj[it.key()] = convertNlohmannJson(it.value()); + } + return obj; + } + return Json::Value(); +} + + + + + std::string TemplateRenderer::renderFile(const std::string& template_path, const Json::Value& data) { try { - // Convert JsonCpp Value to string - std::string dataStr = jsonToString(data); - - // Parse as inja json - auto inja_data = inja::json::parse(dataStr); + // Convert Json::Value to nlohmann::json + auto json_data = convertJsonValue(data); // Load and render template - return env_.render_file(template_path, inja_data); + return env_.render_file(template_path, json_data); } catch (const std::exception& e) { throw std::runtime_error(std::string("Template file rendering failed: ") + e.what()); diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h index 0a382eee6..413108fb6 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -2,15 +2,21 @@ #include #include +#include #include #include "json/json.h" #include "trantor/utils/Logger.h" - class TemplateRenderer { public: TemplateRenderer(); ~TemplateRenderer() = default; + // Convert Json::Value to nlohmann::json + static nlohmann::json convertJsonValue(const Json::Value& input); + + // Convert nlohmann::json to Json::Value + static Json::Value convertNlohmannJson(const nlohmann::json& input); + // Render template with data std::string render(const std::string& tmpl, const Json::Value& data); @@ -18,11 +24,5 @@ class TemplateRenderer { std::string renderFile(const std::string& template_path, const Json::Value& data); private: - // Helper function to convert JsonCpp Value to string representation - static std::string jsonToString(const Json::Value& value); - - // Helper function to validate JSON string - static bool validateJson(const std::string& jsonStr); - inja::Environment env_; }; \ No newline at end of file From 135c41ededef9bcd61e4379eabee0c8590ba3e1b Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Tue, 12 Nov 2024 00:02:28 +0700 Subject: [PATCH 03/33] Fix: CI build windows --- .../remote-engine/TemplateRenderer.cc | 236 ++++++++---------- .../remote-engine/TemplateRenderer.h | 36 +-- 2 files changed, 130 insertions(+), 142 deletions(-) diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 89515ca38..67b7e13d3 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -1,145 +1,129 @@ #include "TemplateRenderer.h" -#include #include - +#include TemplateRenderer::TemplateRenderer() { - // Configure Inja environment - env_.set_trim_blocks(true); - env_.set_lstrip_blocks(true); - - // Add tojson function for all value types - env_.add_callback("tojson", 1, [](inja::Arguments& args) { - if (args.empty()) { - return nlohmann::json(nullptr); - } - const auto& value = *args[0]; - - if (value.is_string()) { - return nlohmann::json(std::string("\"") + value.get() + "\""); - } - return value; - }); -} + // Configure Inja environment + env_.set_trim_blocks(true); + env_.set_lstrip_blocks(true); -std::string TemplateRenderer::render(const std::string& tmpl, const Json::Value& data) { - try { - // Convert Json::Value to nlohmann::json - auto json_data = convertJsonValue(data); - - // Create the input data structure expected by the template - nlohmann::json template_data; - template_data["input_request"] = json_data; - - // Debug output - LOG_DEBUG << "Template: " << tmpl; - LOG_DEBUG << "Data: " << template_data.dump(2); - - // Render template - std::string result = env_.render(tmpl, template_data); - - // Clean up any potential double quotes in JSON strings - result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); - - LOG_DEBUG << "Result: " << result; - - // Validate JSON - auto parsed = nlohmann::json::parse(result); - - return result; + // Add tojson function for all value types + env_.add_callback("tojson", 1, [](inja::Arguments& args) { + if (args.empty()) { + return nlohmann::json(nullptr); } - catch (const std::exception& e) { - LOG_ERROR << "Template rendering failed: " << e.what(); - LOG_ERROR << "Template: " << tmpl; - throw std::runtime_error(std::string("Template rendering failed: ") + e.what()); + const auto& value = *args[0]; + + if (value.is_string()) { + return nlohmann::json(std::string("\"") + value.get() + + "\""); } + return value; + }); +} + +std::string TemplateRenderer::render(const std::string& tmpl, + const Json::Value& data) { + try { + // Convert Json::Value to nlohmann::json + auto json_data = convertJsonValue(data); + + // Create the input data structure expected by the template + nlohmann::json template_data; + template_data["input_request"] = json_data; + + // Debug output + LOG_DEBUG << "Template: " << tmpl; + LOG_DEBUG << "Data: " << template_data.dump(2); + + // Render template + std::string result = env_.render(tmpl, template_data); + + // Clean up any potential double quotes in JSON strings + result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); + + LOG_DEBUG << "Result: " << result; + + // Validate JSON + auto parsed = nlohmann::json::parse(result); + + return result; + } catch (const std::exception& e) { + LOG_ERROR << "Template rendering failed: " << e.what(); + LOG_ERROR << "Template: " << tmpl; + throw std::runtime_error(std::string("Template rendering failed: ") + + e.what()); + } } nlohmann::json TemplateRenderer::convertJsonValue(const Json::Value& input) { - if (input.isNull()) { - return nullptr; - } - else if (input.isBool()) { - return input.asBool(); - } - else if (input.isInt()) { - return input.asInt(); - } - else if (input.isUInt()) { - return input.asUInt(); - } - else if (input.isDouble()) { - return input.asDouble(); - } - else if (input.isString()) { - return input.asString(); - } - else if (input.isArray()) { - nlohmann::json arr = nlohmann::json::array(); - for (const auto& element : input) { - arr.push_back(convertJsonValue(element)); - } - return arr; - } - else if (input.isObject()) { - nlohmann::json obj = nlohmann::json::object(); - for (const auto& key : input.getMemberNames()) { - obj[key] = convertJsonValue(input[key]); - } - return obj; - } + if (input.isNull()) { return nullptr; + } else if (input.isBool()) { + return input.asBool(); + } else if (input.isInt()) { + return input.asInt(); + } else if (input.isUInt()) { + return input.asUInt(); + } else if (input.isDouble()) { + return input.asDouble(); + } else if (input.isString()) { + return input.asString(); + } else if (input.isArray()) { + nlohmann::json arr = nlohmann::json::array(); + for (const auto& element : input) { + arr.push_back(convertJsonValue(element)); + } + return arr; + } else if (input.isObject()) { + nlohmann::json obj = nlohmann::json::object(); + for (const auto& key : input.getMemberNames()) { + obj[key] = convertJsonValue(input[key]); + } + return obj; + } + return nullptr; } Json::Value TemplateRenderer::convertNlohmannJson(const nlohmann::json& input) { - if (input.is_null()) { - return Json::Value(); - } - else if (input.is_boolean()) { - return Json::Value(input.get()); - } - else if (input.is_number_integer()) { - return Json::Value(input.get()); - } - else if (input.is_number_unsigned()) { - return Json::Value(input.get()); - } - else if (input.is_number_float()) { - return Json::Value(input.get()); - } - else if (input.is_string()) { - return Json::Value(input.get()); - } - else if (input.is_array()) { - Json::Value arr(Json::arrayValue); - for (const auto& element : input) { - arr.append(convertNlohmannJson(element)); - } - return arr; - } - else if (input.is_object()) { - Json::Value obj(Json::objectValue); - for (auto it = input.begin(); it != input.end(); ++it) { - obj[it.key()] = convertNlohmannJson(it.value()); - } - return obj; - } + if (input.is_null()) { return Json::Value(); + } else if (input.is_boolean()) { + return Json::Value(input.get()); + } else if (input.is_number_integer()) { + return Json::Value(input.get()); + } else if (input.is_number_unsigned()) { + return Json::Value(input.get()); + } else if (input.is_number_float()) { + return Json::Value(input.get()); + } else if (input.is_string()) { + return Json::Value(input.get()); + } else if (input.is_array()) { + Json::Value arr(Json::arrayValue); + for (const auto& element : input) { + arr.append(convertNlohmannJson(element)); + } + return arr; + } else if (input.is_object()) { + Json::Value obj(Json::objectValue); + for (auto it = input.begin(); it != input.end(); ++it) { + obj[it.key()] = convertNlohmannJson(it.value()); + } + return obj; + } + return Json::Value(); } +std::string TemplateRenderer::renderFile(const std::string& template_path, + const Json::Value& data) { + try { + // Convert Json::Value to nlohmann::json + auto json_data = convertJsonValue(data); - - - -std::string TemplateRenderer::renderFile(const std::string& template_path, const Json::Value& data) { - try { - // Convert Json::Value to nlohmann::json - auto json_data = convertJsonValue(data); - - // Load and render template - return env_.render_file(template_path, json_data); - } - catch (const std::exception& e) { - throw std::runtime_error(std::string("Template file rendering failed: ") + e.what()); - } + // Load and render template + return env_.render_file(template_path, json_data); + } catch (const std::exception& e) { + throw std::runtime_error(std::string("Template file rendering failed: ") + + e.what()); + } } \ No newline at end of file diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h index 413108fb6..aecb899e3 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -1,28 +1,32 @@ #pragma once -#include -#include +// clang-format off #include #include +// clang-format on +#include + +#include #include "json/json.h" #include "trantor/utils/Logger.h" class TemplateRenderer { -public: - TemplateRenderer(); - ~TemplateRenderer() = default; + public: + TemplateRenderer(); + ~TemplateRenderer() = default; + + // Convert Json::Value to nlohmann::json + static nlohmann::json convertJsonValue(const Json::Value& input); - // Convert Json::Value to nlohmann::json - static nlohmann::json convertJsonValue(const Json::Value& input); - - // Convert nlohmann::json to Json::Value - static Json::Value convertNlohmannJson(const nlohmann::json& input); + // Convert nlohmann::json to Json::Value + static Json::Value convertNlohmannJson(const nlohmann::json& input); - // Render template with data - std::string render(const std::string& tmpl, const Json::Value& data); + // Render template with data + std::string render(const std::string& tmpl, const Json::Value& data); - // Load template from file and render - std::string renderFile(const std::string& template_path, const Json::Value& data); + // Load template from file and render + std::string renderFile(const std::string& template_path, + const Json::Value& data); -private: - inja::Environment env_; + private: + inja::Environment env_; }; \ No newline at end of file From a916ec865c9189b902ee32a34b951453dc3ce676 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Tue, 12 Nov 2024 00:19:22 +0700 Subject: [PATCH 04/33] Fix: CI build windows --- .../extensions/remote-engine/TemplateRenderer.cc | 8 ++++++-- .../extensions/remote-engine/TemplateRenderer.h | 16 +++++++++++----- engine/extensions/remote-engine/remote_engine.cc | 5 +++-- engine/extensions/remote-engine/remote_engine.h | 6 +++++- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 67b7e13d3..12c80756c 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -1,7 +1,10 @@ +#if defined(_WIN32) || defined(_WIN64) +#define NOMINMAX +#endif #include "TemplateRenderer.h" #include #include - +namespace remote_engine { TemplateRenderer::TemplateRenderer() { // Configure Inja environment env_.set_trim_blocks(true); @@ -126,4 +129,5 @@ std::string TemplateRenderer::renderFile(const std::string& template_path, throw std::runtime_error(std::string("Template file rendering failed: ") + e.what()); } -} \ No newline at end of file +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h index aecb899e3..7cd08a94e 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -1,14 +1,18 @@ #pragma once -// clang-format off -#include -#include -// clang-format on #include #include #include "json/json.h" #include "trantor/utils/Logger.h" +// clang-format off +#if defined(_WIN32) || defined(_WIN64) +#define NOMINMAX +#endif +#include +#include +// clang-format on +namespace remote_engine { class TemplateRenderer { public: TemplateRenderer(); @@ -29,4 +33,6 @@ class TemplateRenderer { private: inja::Environment env_; -}; \ No newline at end of file +}; + +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index bab63bb75..12cb969b7 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -3,7 +3,7 @@ #include #include #include - +namespace remote_engine { constexpr const int k200OK = 200; constexpr const int k400BadRequest = 400; constexpr const int k409Conflict = 409; @@ -422,4 +422,5 @@ extern "C" { EngineI* get_engine() { return new RemoteEngine(); } -} \ No newline at end of file +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index e55107286..ac77e8015 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -11,6 +11,8 @@ #include "extensions/remote-engine/TemplateRenderer.h" #include "utils/file_logger.h" // Helper for CURL response + +namespace remote_engine { struct CurlResponse { std::string body; bool error{false}; @@ -81,4 +83,6 @@ class RemoteEngine : public EngineI { bool IsSupported(const std::string& feature) override; bool SetFileLogger(int max_log_lines, const std::string& log_path) override; void SetLogLevel(trantor::Logger::LogLevel logLevel) override; -}; \ No newline at end of file +}; + +} // namespace remote_engine \ No newline at end of file From a9c0d8bd01630f56692d33b315e1f08f92332e9e Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Tue, 12 Nov 2024 08:51:23 +0700 Subject: [PATCH 05/33] Fix: CI build windows --- engine/extensions/remote-engine/TemplateRenderer.cc | 2 ++ engine/extensions/remote-engine/TemplateRenderer.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 12c80756c..3c3a0ea00 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -1,5 +1,7 @@ #if defined(_WIN32) || defined(_WIN64) #define NOMINMAX +#undef min +#undef max #endif #include "TemplateRenderer.h" #include diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h index 7cd08a94e..7f2f7fd88 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -8,6 +8,8 @@ // clang-format off #if defined(_WIN32) || defined(_WIN64) #define NOMINMAX +#undef min +#undef max #endif #include #include From 9d1a9d8a229630a7e85f642e33c588e50ef5bdc8 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Tue, 12 Nov 2024 17:51:14 +0700 Subject: [PATCH 06/33] feat: new db schema for model and template for engine --- engine/controllers/models.cc | 13 ++- engine/database/engines.cc | 19 ++++ engine/database/engines.h | 32 ++++++ engine/database/models.cc | 106 +++++++++++++----- engine/database/models.h | 17 ++- .../remote-engine/TemplateRenderer.cc | 2 + .../remote-engine/TemplateRenderer.h | 2 + engine/test/components/test_models_db.cc | 20 ++-- 8 files changed, 175 insertions(+), 36 deletions(-) create mode 100644 engine/database/engines.cc create mode 100644 engine/database/engines.h diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index c205e85df..7552a0c47 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -334,8 +334,17 @@ void Models::ImportModel( // Use relative path for model_yaml_path. In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - cortex::db::ModelEntry model_entry{modelHandle, "local", "imported", - yaml_rel_path.string(), modelHandle}; + cortex::db::ModelEntry model_entry { + modelHandle, + "local", + "imported", + cortex::db::ModelStatus::Downloaded, + "", + "", + "", + yaml_rel_path.string(), + modelHandle + }; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); diff --git a/engine/database/engines.cc b/engine/database/engines.cc new file mode 100644 index 000000000..ec8049c0a --- /dev/null +++ b/engine/database/engines.cc @@ -0,0 +1,19 @@ +#include "engines.h" +#include "database.h" + +namespace cortex::db { + +Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { + db_.exec( + "CREATE TABLE IF NOT EXISTS engines (" + "engine_id TEXT PRIMARY KEY," + "type TEXT," + "api_key TEXT," + "url TEXT," + "version TEXT," + "variant TEXT," + "status TEXT," + "metadata TEXT);"); +} +Engines::~Engines() {} +} \ No newline at end of file diff --git a/engine/database/engines.h b/engine/database/engines.h new file mode 100644 index 000000000..737ba0fb0 --- /dev/null +++ b/engine/database/engines.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include +#include +#include "utils/result.hpp" + +namespace cortex::db { + +struct EngineEntry { + std::string engine; +}; + +class Engines { + + private: + SQLite::Database& db_; + + bool IsUnique(const std::vector& entries, + const std::string& model_id, + const std::string& model_alias) const; + + cpp::result, std::string> LoadModelListNoLock() const; + + public: + Engines(); + Engines(SQLite::Database& db); + ~Engines(); +}; + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.cc b/engine/database/models.cc index 67ecb9723..0c78fc7e0 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -12,6 +12,10 @@ Models::Models() : db_(cortex::db::Database::GetInstance().db()) { db_.exec( "CREATE TABLE IF NOT EXISTS models (" "model_id TEXT PRIMARY KEY," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT," "author_repo_id TEXT," "branch_name TEXT," "path_to_model_yaml TEXT," @@ -22,14 +26,40 @@ Models::Models(SQLite::Database& db) : db_(db) { db_.exec( "CREATE TABLE IF NOT EXISTS models (" "model_id TEXT PRIMARY KEY," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT," "author_repo_id TEXT," "branch_name TEXT," "path_to_model_yaml TEXT," - "model_alias TEXT UNIQUE);"); + "model_alias TEXT);"); } - Models::~Models() {} +std::string Models::StatusToString(ModelStatus status) const { + switch (status) { + case ModelStatus::Remote: + return "remote"; + case ModelStatus::Downloaded: + return "downloaded"; + case ModelStatus::Undownloaded: + return "undownloaded"; + } + return "unknown"; +} + +ModelStatus Models::StringToStatus(const std::string& status_str) const { + if (status_str == "remote") { + return ModelStatus::Remote; + } else if (status_str == "downloaded") { + return ModelStatus::Downloaded; + } else if (status_str == "undownloaded") { + return ModelStatus::Undownloaded; + } + throw std::invalid_argument("Invalid status string"); +} + cpp::result, std::string> Models::LoadModelList() const { try { @@ -57,16 +87,21 @@ cpp::result, std::string> Models::LoadModelListNoLock() try { std::vector entries; SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " + "SELECT model_id, model_format, model_source, " + "status, engine, author_repo_id, branch_name, " "path_to_model_yaml, model_alias FROM models"); while (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); - entry.author_repo_id = query.getColumn(1).getString(); - entry.branch_name = query.getColumn(2).getString(); - entry.path_to_model_yaml = query.getColumn(3).getString(); - entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(1).getString(); + entry.model_source = query.getColumn(2).getString(); + entry.status = StringToStatus(query.getColumn(3).getString()); + entry.engine = query.getColumn(4).getString(); + entry.author_repo_id = query.getColumn(5).getString(); + entry.branch_name = query.getColumn(6).getString(); + entry.path_to_model_yaml = query.getColumn(7).getString(); + entry.model_alias = query.getColumn(8).getString(); entries.push_back(entry); } return entries; @@ -140,7 +175,8 @@ cpp::result Models::GetModelInfo( const std::string& identifier) const { try { SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " + "SELECT model_id, model_format, model_source, " + "status, engine, author_repo_id, branch_name, " "path_to_model_yaml, model_alias FROM models " "WHERE model_id = ? OR model_alias = ?"); @@ -149,10 +185,14 @@ cpp::result Models::GetModelInfo( if (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); - entry.author_repo_id = query.getColumn(1).getString(); - entry.branch_name = query.getColumn(2).getString(); - entry.path_to_model_yaml = query.getColumn(3).getString(); - entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(1).getString(); + entry.model_source = query.getColumn(2).getString(); + entry.status = StringToStatus(query.getColumn(3).getString()); + entry.engine = query.getColumn(4).getString(); + entry.author_repo_id = query.getColumn(5).getString(); + entry.branch_name = query.getColumn(6).getString(); + entry.path_to_model_yaml = query.getColumn(7).getString(); + entry.model_alias = query.getColumn(8).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -164,6 +204,10 @@ cpp::result Models::GetModelInfo( void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Model ID: " << entry.model; + LOG_INFO << "Model Format: " << entry.model_format; + LOG_INFO << "Model Source: " << entry.model_source; + LOG_INFO << "Status: " << StatusToString(entry.status); + LOG_INFO << "Engine: " << entry.engine; LOG_INFO << "Author/Repo ID: " << entry.author_repo_id; LOG_INFO << "Branch Name: " << entry.branch_name; LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; @@ -188,14 +232,18 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, SQLite::Statement insert( db_, - "INSERT INTO models (model_id, author_repo_id, " - "branch_name, path_to_model_yaml, model_alias) VALUES (?, ?, " - "?, ?, ?)"); + "INSERT INTO models (model_id, model_format, model_source, status, " + "engine, author_repo_id, branch_name, path_to_model_yaml, model_alias) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); - insert.bind(2, new_entry.author_repo_id); - insert.bind(3, new_entry.branch_name); - insert.bind(4, new_entry.path_to_model_yaml); - insert.bind(5, new_entry.model_alias); + insert.bind(2, new_entry.model_format); + insert.bind(3, new_entry.model_source); + insert.bind(4, StatusToString(new_entry.status)); + insert.bind(5, new_entry.engine); + insert.bind(6, new_entry.author_repo_id); + insert.bind(7, new_entry.branch_name); + insert.bind(8, new_entry.path_to_model_yaml); + insert.bind(9, new_entry.model_alias); insert.exec(); return true; @@ -215,14 +263,19 @@ cpp::result Models::UpdateModelEntry( try { SQLite::Statement upd(db_, "UPDATE models " - "SET author_repo_id = ?, branch_name = ?, " + "SET model_format = ?, model_source = ?, status = ?, " + "engine = ?, author_repo_id = ?, branch_name = ?, " "path_to_model_yaml = ? " "WHERE model_id = ? OR model_alias = ?"); - upd.bind(1, updated_entry.author_repo_id); - upd.bind(2, updated_entry.branch_name); - upd.bind(3, updated_entry.path_to_model_yaml); - upd.bind(4, identifier); - upd.bind(5, identifier); + upd.bind(1, updated_entry.model_format); + upd.bind(2, updated_entry.model_source); + upd.bind(3, StatusToString(updated_entry.status)); + upd.bind(4, updated_entry.engine); + upd.bind(5, updated_entry.author_repo_id); + upd.bind(6, updated_entry.branch_name); + upd.bind(7, updated_entry.path_to_model_yaml); + upd.bind(8, identifier); + upd.bind(9, identifier); return upd.exec() == 1; } catch (const std::exception& e) { return cpp::fail(e.what()); @@ -305,4 +358,5 @@ bool Models::HasModel(const std::string& identifier) const { return false; } } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.h b/engine/database/models.h index ebb006b28..705170f6f 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -7,8 +7,19 @@ #include "utils/result.hpp" namespace cortex::db { + +enum class ModelStatus { + Remote, + Downloaded, + Undownloaded +}; + struct ModelEntry { std::string model; + std::string model_format; + std::string model_source; + ModelStatus status; + std::string engine; std::string author_repo_id; std::string branch_name; std::string path_to_model_yaml; @@ -26,6 +37,9 @@ class Models { cpp::result, std::string> LoadModelListNoLock() const; + std::string StatusToString(ModelStatus status) const; + ModelStatus StringToStatus(const std::string& status_str) const; + public: static const std::string kModelListPath; cpp::result, std::string> LoadModelList() const; @@ -50,4 +64,5 @@ class Models { const std::string& identifier) const; bool HasModel(const std::string& identifier) const; }; -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 12c80756c..3c3a0ea00 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -1,5 +1,7 @@ #if defined(_WIN32) || defined(_WIN64) #define NOMINMAX +#undef min +#undef max #endif #include "TemplateRenderer.h" #include diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/TemplateRenderer.h index 7cd08a94e..7f2f7fd88 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/TemplateRenderer.h @@ -8,6 +8,8 @@ // clang-format off #if defined(_WIN32) || defined(_WIN64) #define NOMINMAX +#undef min +#undef max #endif #include #include diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index ef54fe7e0..a743c9221 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -6,6 +6,7 @@ namespace cortex::db { namespace { constexpr const auto kTestDb = "./test.db"; } + class ModelsTestSuite : public ::testing::Test { public: ModelsTestSuite() @@ -21,9 +22,9 @@ class ModelsTestSuite : public ::testing::Test { SQLite::Database db_; cortex::db::Models model_list_; - const cortex::db::ModelEntry kTestModel{"test_model_id", "test_author", - "main", "/path/to/model.yaml", - "test_alias"}; + const cortex::db::ModelEntry kTestModel{ + "test_model_id", "test_format", "test_source", cortex::db::ModelStatus::Downloaded, "test_engine", + "test_author", "main", "/path/to/model.yaml", "test_alias"}; }; TEST_F(ModelsTestSuite, TestAddModelEntry) { @@ -33,8 +34,12 @@ TEST_F(ModelsTestSuite, TestAddModelEntry) { EXPECT_TRUE(retrieved_model); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); + EXPECT_EQ(retrieved_model.value().model_format, kTestModel.model_format); + EXPECT_EQ(retrieved_model.value().model_source, kTestModel.model_source); + EXPECT_EQ(retrieved_model.value().status, kTestModel.status); + EXPECT_EQ(retrieved_model.value().engine, kTestModel.engine); - // // Clean up + // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } @@ -59,14 +64,14 @@ TEST_F(ModelsTestSuite, TestUpdateModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); cortex::db::ModelEntry updated_model = kTestModel; + updated_model.status = cortex::db::ModelStatus::Downloaded; EXPECT_TRUE( model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); EXPECT_TRUE(retrieved_model); - EXPECT_TRUE( - model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); + EXPECT_EQ(retrieved_model.value().status, updated_model.status); // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -162,4 +167,5 @@ TEST_F(ModelsTestSuite, TestHasModel) { // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file From c43591864426cbef13f8d733c3d039705f88dea7 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Tue, 12 Nov 2024 23:35:13 +0700 Subject: [PATCH 07/33] Add remote model --- engine/config/model_config.h | 81 ++++++++++++++++++++++ engine/controllers/models.cc | 106 ++++++++++++++++++++++++++--- engine/controllers/models.h | 5 ++ engine/services/engine_service.h | 1 + engine/utils/remote_models_utils.h | 102 +++++++++++++++++++++++++++ 5 files changed, 284 insertions(+), 11 deletions(-) create mode 100644 engine/utils/remote_models_utils.h diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 044fd8dd3..d17c52b4a 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -2,13 +2,94 @@ #include #include +#include #include #include #include +#include #include #include #include "utils/format_utils.h" +#include "utils/remote_models_utils.h" +#include "yaml-cpp/yaml.h" namespace config { + +struct RemoteModelConfig { + std::string model; + std::string api_key_template; + std::string engine; + std::string version; + Json::Value inference_params; + Json::Value TransformReq; + Json::Value TransformResp; + Json::Value metadata; + void LoadFromJson(const Json::Value& json) { + if (!json.isObject()) { + throw std::runtime_error("Input JSON must be an object"); + } + + // Load basic string fields + model = json.get("model", "").asString(); + api_key_template = json.get("api_key_template", "").asString(); + engine = json.get("engine", "").asString(); + version = json.get("version", "").asString(); + + // Load JSON object fields directly + inference_params = + json.get("inference_params", Json::Value(Json::objectValue)); + TransformReq = json.get("TransformReq", Json::Value(Json::objectValue)); + TransformResp = json.get("TransformResp", Json::Value(Json::objectValue)); + metadata = json.get("metadata", Json::Value(Json::objectValue)); + } + + void SaveToYamlFile(const std::string& filepath) const { + YAML::Node root; + + // Convert basic fields + root["model"] = model; + root["api_key_template"] = api_key_template; + root["engine"] = engine; + root["version"] = version; + + // Convert Json::Value to YAML::Node using utility function + root["inference_params"] = + remote_models_utils::jsonToYaml(inference_params); + root["TransformReq"] = remote_models_utils::jsonToYaml(TransformReq); + root["TransformResp"] = remote_models_utils::jsonToYaml(TransformResp); + root["metadata"] = remote_models_utils::jsonToYaml(metadata); + + // Save to file + std::ofstream fout(filepath); + if (!fout.is_open()) { + throw std::runtime_error("Failed to open file for writing: " + filepath); + } + fout << root; + } + + void LoadFromYamlFile(const std::string& filepath) { + YAML::Node root; + try { + root = YAML::LoadFile(filepath); + } catch (const YAML::Exception& e) { + throw std::runtime_error("Failed to parse YAML file: " + + std::string(e.what())); + } + + // Load basic fields + model = root["model"].as(""); + api_key_template = root["api_key_template"].as(""); + engine = root["engine"].as(""); + version = root["version"] ? root["version"].as() : ""; + + // Load complex fields using utility function + inference_params = + remote_models_utils::yamlToJson(root["inference_params"]); + TransformReq = remote_models_utils::yamlToJson(root["TransformReq"]); + TransformResp = remote_models_utils::yamlToJson(root["TransformResp"]); + metadata = remote_models_utils::yamlToJson(root["metadata"]); + } +}; + struct ModelConfig { std::string name; std::string model; diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 7552a0c47..e4b3444bb 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -334,17 +334,10 @@ void Models::ImportModel( // Use relative path for model_yaml_path. In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - cortex::db::ModelEntry model_entry { - modelHandle, - "local", - "imported", - cortex::db::ModelStatus::Downloaded, - "", - "", - "", - yaml_rel_path.string(), - modelHandle - }; + cortex::db::ModelEntry model_entry{ + modelHandle, "local", "imported", cortex::db::ModelStatus::Downloaded, + "", "", "", yaml_rel_path.string(), + modelHandle}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -545,3 +538,94 @@ void Models::GetModelStatus( callback(resp); } } + +void Models::AddRemoteModel( + const HttpRequestPtr& req, + std::function&& callback) const { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + if (!http_util::HasFieldInReq(req, callback, "model") || + !http_util::HasFieldInReq(req, callback, "engine")) { + return; + } + + auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); + /* To do: uncomment when remote engine is ready + + auto engine_validate = engine_service_->IsEngineReady(engine_name); + if (engine_validate.has_error()) { + Json::Value ret; + ret["message"] = engine_validate.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + if (!engine_validate.value()) { + Json::Value ret; + ret["message"] = "Engine is not ready! Please install first!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + */ + config::RemoteModelConfig model_config; + model_config.LoadFromJson(*(req->getJsonObject())); + cortex::db::Models modellist_utils_obj; + std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / + std::filesystem::path("remote") / + std::filesystem::path(model_handle + ".yml")) + .string(); + try { + // Use relative path for model_yaml_path. In case of import, we use absolute path for model + auto yaml_rel_path = + fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); + // TODO: remove hardcode "openai" when engine is finish + cortex::db::ModelEntry model_entry{ + model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, + "openai", "", "", yaml_rel_path.string(), + model_handle}; + std::filesystem::create_directories( + std::filesystem::path(model_yaml_path).parent_path()); + if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + model_config.SaveToYamlFile(model_yaml_path); + std::string success_message = "Model is imported successfully!"; + LOG_INFO << success_message; + Json::Value ret; + ret["result"] = "OK"; + ret["modelHandle"] = model_handle; + ret["message"] = success_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + + } else { + std::string error_message = "Fail to import model, model_id '" + + model_handle + "' already exists!"; + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Import failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } + } catch (const std::exception& e) { + std::string error_message = + "Error while adding Remote model with model_id '" + model_handle + + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Add failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} \ No newline at end of file diff --git a/engine/controllers/models.h b/engine/controllers/models.h index da6caf024..24adffd1c 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -21,6 +21,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::StartModel, "/start", Options, Post); METHOD_ADD(Models::StopModel, "/stop", Options, Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); + METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -32,6 +33,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Options, Post); ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); + ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); METHOD_LIST_END explicit Models(std::shared_ptr model_service, @@ -56,6 +58,9 @@ class Models : public drogon::HttpController { void ImportModel( const HttpRequestPtr& req, std::function&& callback) const; + void AddRemoteModel( + const HttpRequestPtr& req, + std::function&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function&& callback, const std::string& model_id); diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 4e58fccfd..d1a9807e8 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -13,6 +13,7 @@ #include "utils/github_release_utils.h" #include "utils/result.hpp" #include "utils/system_info_utils.h" +#include "extensions/remote-engine/remote_engine.h" // TODO: namh think of the other name struct DefaultEngineVariant { diff --git a/engine/utils/remote_models_utils.h b/engine/utils/remote_models_utils.h new file mode 100644 index 000000000..6c13f62c1 --- /dev/null +++ b/engine/utils/remote_models_utils.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include + +namespace remote_models_utils { + +inline Json::Value yamlToJson(const YAML::Node& node) { + Json::Value result; + + switch (node.Type()) { + case YAML::NodeType::Null: + return Json::Value(); + case YAML::NodeType::Scalar: { + // For scalar types, we'll first try to parse as string + std::string str_val = node.as(); + + // Try to parse as boolean + if (str_val == "true" || str_val == "True" || str_val == "TRUE") + return Json::Value(true); + if (str_val == "false" || str_val == "False" || str_val == "FALSE") + return Json::Value(false); + + // Try to parse as number + try { + // Check if it's an integer + size_t pos; + long long int_val = std::stoll(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(static_cast(int_val)); + } + + // Check if it's a float + double float_val = std::stod(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(float_val); + } + } catch (...) { + // If parsing as number fails, use as string + } + + // Default to string if no other type matches + return Json::Value(str_val); + } + case YAML::NodeType::Sequence: { + result = Json::Value(Json::arrayValue); + for (const auto& elem : node) { + result.append(yamlToJson(elem)); + } + return result; + } + case YAML::NodeType::Map: { + result = Json::Value(Json::objectValue); + for (const auto& it : node) { + std::string key = it.first.as(); + result[key] = yamlToJson(it.second); + } + return result; + } + default: + return Json::Value(); + } +} + +inline YAML::Node jsonToYaml(const Json::Value& json) { + YAML::Node result; + + switch (json.type()) { + case Json::nullValue: + result = YAML::Node(YAML::NodeType::Null); + break; + case Json::intValue: + result = json.asInt64(); + break; + case Json::uintValue: + result = json.asUInt64(); + break; + case Json::realValue: + result = json.asDouble(); + break; + case Json::stringValue: + result = json.asString(); + break; + case Json::booleanValue: + result = json.asBool(); + break; + case Json::arrayValue: + result = YAML::Node(YAML::NodeType::Sequence); + for (const auto& elem : json) + result.push_back(jsonToYaml(elem)); + break; + case Json::objectValue: + result = YAML::Node(YAML::NodeType::Map); + for (const auto& key : json.getMemberNames()) + result[key] = jsonToYaml(json[key]); + break; + } + return result; +} + +} // namespace utils \ No newline at end of file From 28d31060143e254b289e34c2e9b220959a6bc428 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 13 Nov 2024 14:31:52 +0700 Subject: [PATCH 08/33] Add Get, List, Update support for remote models --- engine/config/model_config.h | 52 ++++++++++++++++++++----- engine/controllers/models.cc | 73 ++++++++++++++++++++++++++++-------- 2 files changed, 101 insertions(+), 24 deletions(-) diff --git a/engine/config/model_config.h b/engine/config/model_config.h index d17c52b4a..734a34558 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -19,6 +20,9 @@ struct RemoteModelConfig { std::string api_key_template; std::string engine; std::string version; + std::size_t created; + std::string object = "model"; + std::string owned_by = ""; Json::Value inference_params; Json::Value TransformReq; Json::Value TransformResp; @@ -29,19 +33,43 @@ struct RemoteModelConfig { } // Load basic string fields - model = json.get("model", "").asString(); - api_key_template = json.get("api_key_template", "").asString(); - engine = json.get("engine", "").asString(); - version = json.get("version", "").asString(); + model = json.get("model", model).asString(); + api_key_template = + json.get("api_key_template", api_key_template).asString(); + engine = json.get("engine", engine).asString(); + version = json.get("version", version).asString(); + created = json.get("created", created).asUInt64(); + object = json.get("object", object).asString(); + owned_by = json.get("owned_by", owned_by).asString(); // Load JSON object fields directly - inference_params = - json.get("inference_params", Json::Value(Json::objectValue)); - TransformReq = json.get("TransformReq", Json::Value(Json::objectValue)); - TransformResp = json.get("TransformResp", Json::Value(Json::objectValue)); - metadata = json.get("metadata", Json::Value(Json::objectValue)); + inference_params = json.get("inference_params", inference_params); + TransformReq = json.get("TransformReq", TransformReq); + TransformResp = json.get("TransformResp", TransformResp); + metadata = json.get("metadata", metadata); } + Json::Value ToJson() const { + Json::Value json; + + // Add basic string fields + json["model"] = model; + json["api_key_template"] = api_key_template; + json["engine"] = engine; + json["version"] = version; + json["created"] = static_cast(created); + json["object"] = object; + json["owned_by"] = owned_by; + + // Add JSON object fields directly + json["inference_params"] = inference_params; + json["TransformReq"] = TransformReq; + json["TransformResp"] = TransformResp; + json["metadata"] = metadata; + + return json; + }; + void SaveToYamlFile(const std::string& filepath) const { YAML::Node root; @@ -50,6 +78,9 @@ struct RemoteModelConfig { root["api_key_template"] = api_key_template; root["engine"] = engine; root["version"] = version; + root["object"] = object; + root["owned_by"] = owned_by; + root["created"] = std::time(nullptr); // Convert Json::Value to YAML::Node using utility function root["inference_params"] = @@ -80,6 +111,9 @@ struct RemoteModelConfig { api_key_template = root["api_key_template"].as(""); engine = root["engine"].as(""); version = root["version"] ? root["version"].as() : ""; + created = root["created"] ? root["created"].as() : 0; + object = root["object"] ? root["object"].as() : "model"; + owned_by = root["owned_by"] ? root["owned_by"].as() : ""; // Load complex fields using utility function inference_params = diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index e4b3444bb..20075b206 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -7,6 +7,7 @@ #include "models.h" #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" +#include "utils/engine_constants.h" #include "utils/file_manager_utils.h" #include "utils/http_util.h" #include "utils/logging_utils.h" @@ -168,11 +169,27 @@ void Models::ListModel( fs::path(model_entry.path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); - Json::Value obj = model_config.ToJson(); - obj["id"] = model_entry.model; - obj["model"] = model_entry.model; - data.append(std::move(obj)); - yaml_handler.Reset(); + + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + Json::Value obj = model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + data.append(std::move(obj)); + yaml_handler.Reset(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + Json::Value obj = remote_model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + data.append(std::move(obj)); + } + } catch (const std::exception& e) { LOG_ERROR << "Failed to load yaml file for model: " << model_entry.path_to_model_yaml << ", error: " << e.what(); @@ -218,17 +235,31 @@ void Models::GetModel(const HttpRequestPtr& req, callback(resp); return; } + yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + ret = model_config.ToJson(); - ret = model_config.ToJson(); - - ret["id"] = model_config.model; - ret["object"] = "model"; - ret["result"] = "OK"; + ret["id"] = model_config.model; + ret["object"] = "model"; + ret["result"] = "OK"; + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + ret = remote_model_config.ToJson(); + ret["id"] = remote_model_config.model; + ret["object"] = "model"; + ret["result"] = "OK"; + } auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); callback(resp); @@ -279,11 +310,23 @@ void Models::UpdateModel(const HttpRequestPtr& req, fs::path(model_entry.value().path_to_model_yaml)); yaml_handler.ModelConfigFromFile(yaml_fp.string()); config::ModelConfig model_config = yaml_handler.GetModelConfig(); - model_config.FromJson(json_body); - yaml_handler.UpdateModelConfig(model_config); - yaml_handler.WriteYamlFile(yaml_fp.string()); - std::string message = "Successfully update model ID '" + model_id + - "': " + json_body.toStyledString(); + std::string message; + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(yaml_fp.string()); + message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile(yaml_fp.string()); + remote_model_config.LoadFromJson(json_body); + remote_model_config.SaveToYamlFile(yaml_fp.string()); + message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + } LOG_INFO << message; Json::Value ret; ret["result"] = "Updated successfully!"; From 6508c98f2121b9bbf3c2631dcf1d0d6ecd986f89 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Wed, 13 Nov 2024 14:48:50 +0700 Subject: [PATCH 09/33] change model_id to model in remote engine --- .../extensions/remote-engine/remote_engine.cc | 63 ++++++++++--------- .../extensions/remote-engine/remote_engine.h | 6 +- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 12cb969b7..afefe1f0c 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -38,9 +38,9 @@ RemoteEngine::~RemoteEngine() { } RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( - const std::string& model_id) { + const std::string& model) { std::shared_lock lock(models_mutex_); - auto it = models_.find(model_id); + auto it = models_.find(model); if (it != models_.end()) { return &it->second; } @@ -132,18 +132,18 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( return response; } -bool RemoteEngine::LoadModelConfig(const std::string& model_id, +bool RemoteEngine::LoadModelConfig(const std::string& model, const std::string& yaml_path, const std::string& api_key) { try { YAML::Node config = YAML::LoadFile(yaml_path); ModelConfig model_config; - model_config.model_id = model_id; + model_config.model = model; // Required fields if (!config["api_key_template"]) { - LOG_ERROR << "Missing required fields in config for model " << model_id; + LOG_ERROR << "Missing required fields in config for model " << model; return false; } @@ -157,12 +157,12 @@ bool RemoteEngine::LoadModelConfig(const std::string& model_id, if (config["TransformReq"]) { model_config.transform_req = config["TransformReq"]; } else { - LOG_WARN << "Missing TransformReq in config for model " << model_id; + LOG_WARN << "Missing TransformReq in config for model " << model; } if (config["TransformResp"]) { model_config.transform_resp = config["TransformResp"]; } else { - LOG_WARN << "Missing TransformResp in config for model " << model_id; + LOG_WARN << "Missing TransformResp in config for model " << model; } model_config.is_loaded = true; @@ -170,12 +170,12 @@ bool RemoteEngine::LoadModelConfig(const std::string& model_id, // Thread-safe update of models map { std::unique_lock lock(models_mutex_); - models_[model_id] = std::move(model_config); + models_[model] = std::move(model_config); } return true; } catch (const YAML::Exception& e) { - LOG_ERROR << "Failed to load config for model " << model_id << ": " + LOG_ERROR << "Failed to load config for model " << model << ": " << e.what(); return false; } @@ -222,19 +222,19 @@ void RemoteEngine::GetModels( void RemoteEngine::LoadModel( std::shared_ptr json_body, std::function&& callback) { - if (!json_body->isMember("model_id") || !json_body->isMember("model_path") || + if (!json_body->isMember("model") || !json_body->isMember("model_path") || !json_body->isMember("api_key")) { Json::Value error; - error["error"] = "Missing required fields: model_id or model_path"; + error["error"] = "Missing required fields: model or model_path"; callback(Json::Value(), std::move(error)); return; } - const std::string& model_id = (*json_body)["model_id"].asString(); + const std::string& model = (*json_body)["model"].asString(); const std::string& model_path = (*json_body)["model_path"].asString(); const std::string& api_key = (*json_body)["api_key"].asString(); - if (!LoadModelConfig(model_id, model_path, api_key)) { + if (!LoadModelConfig(model, model_path, api_key)) { Json::Value error; error["error"] = "Failed to load model configuration"; callback(Json::Value(), std::move(error)); @@ -252,18 +252,18 @@ void RemoteEngine::LoadModel( void RemoteEngine::UnloadModel( std::shared_ptr json_body, std::function&& callback) { - if (!json_body->isMember("model_id")) { + if (!json_body->isMember("model")) { Json::Value error; - error["error"] = "Missing required field: model_id"; + error["error"] = "Missing required field: model"; callback(Json::Value(), std::move(error)); return; } - const std::string& model_id = (*json_body)["model_id"].asString(); + const std::string& model = (*json_body)["model"].asString(); { std::unique_lock lock(models_mutex_); - models_.erase(model_id); + models_.erase(model); } Json::Value response; @@ -286,8 +286,8 @@ void RemoteEngine::HandleChatCompletion( return; } - const std::string& model_id = (*json_body)["model"].asString(); - auto* model_config = GetModelConfig(model_id); + const std::string& model = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model); if (!model_config || !model_config->is_loaded) { Json::Value status; @@ -296,7 +296,7 @@ void RemoteEngine::HandleChatCompletion( status["is_stream"] = false; status["status_code"] = k400BadRequest; Json::Value error; - error["error"] = "Model not found or not loaded: " + model_id; + error["error"] = "Model not found or not loaded: " + model; callback(std::move(status), std::move(error)); return; } @@ -351,29 +351,34 @@ void RemoteEngine::HandleChatCompletion( void RemoteEngine::GetModelStatus( std::shared_ptr json_body, std::function&& callback) { - if (!json_body->isMember("model_id")) { + if (!json_body->isMember("model")) { Json::Value error; - error["error"] = "Missing required field: model_id"; + error["error"] = "Missing required field: model"; callback(Json::Value(), std::move(error)); return; } - const std::string& model_id = (*json_body)["model_id"].asString(); - auto* model_config = GetModelConfig(model_id); + const std::string& model = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model); if (!model_config) { Json::Value error; - error["error"] = "Model not found: " + model_id; + error["error"] = "Model not found: " + model; callback(Json::Value(), std::move(error)); return; } Json::Value response; - response["model_id"] = model_id; - response["is_loaded"] = model_config->is_loaded; - response["url"] = model_config->url; + response["model"] = model; + response["model_loaded"] = model_config->is_loaded; + response["model_data"] = model_config->url; - callback(std::move(response), Json::Value()); + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); } // Implement remaining virtual functions diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index ac77e8015..2ae613b8d 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -23,7 +23,7 @@ class RemoteEngine : public EngineI { private: // Model configuration struct ModelConfig { - std::string model_id; + std::string model; std::string api_key; std::string url; YAML::Node transform_req; @@ -46,10 +46,10 @@ class RemoteEngine : public EngineI { CurlResponse MakeGetModelsRequest(); // Internal model management - bool LoadModelConfig(const std::string& model_id, + bool LoadModelConfig(const std::string& model, const std::string& yaml_path, const std::string& api_key); - ModelConfig* GetModelConfig(const std::string& model_id); + ModelConfig* GetModelConfig(const std::string& model); public: RemoteEngine(); From 7b295f8cac752ae0e4a5360a194e06d1839a428f Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Wed, 13 Nov 2024 15:04:52 +0700 Subject: [PATCH 10/33] fix: mac compatibility --- engine/config/model_config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 734a34558..1ff127329 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -38,7 +38,7 @@ struct RemoteModelConfig { json.get("api_key_template", api_key_template).asString(); engine = json.get("engine", engine).asString(); version = json.get("version", version).asString(); - created = json.get("created", created).asUInt64(); + created = json.get("created", static_cast(created)).asUInt64(); object = json.get("object", object).asString(); owned_by = json.get("owned_by", owned_by).asString(); From d921869cb636e338d262b5ab0c32a3df4c46775e Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Wed, 13 Nov 2024 16:58:45 +0700 Subject: [PATCH 11/33] chore: some refactors before making big changes --- engine/services/engine_service.cc | 79 ++++++++++++------------------- 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 0120def27..08cd8b122 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -228,27 +228,20 @@ cpp::result EngineService::UninstallEngineVariant( cpp::result EngineService::DownloadEngineV2( const std::string& engine, const std::string& version, const std::optional variant_name) { - auto normalized_version = version == "latest" - ? "latest" - : string_utils::RemoveSubstring(version, "v"); + auto normalized_version = version == "latest" ? "latest" : string_utils::RemoveSubstring(version, "v"); auto res = GetEngineVariants(engine, version); if (res.has_error()) { return cpp::fail("Failed to fetch engine releases: " + res.error()); } - if (res.value().empty()) { return cpp::fail("No release found for " + version); } std::optional selected_variant = std::nullopt; - if (variant_name.has_value()) { - auto latest_version_semantic = normalized_version == "latest" - ? res.value()[0].version - : normalized_version; - auto merged_variant_name = engine + "-" + latest_version_semantic + "-" + - variant_name.value() + ".tar.gz"; + auto latest_version_semantic = normalized_version == "latest" ? res.value()[0].version : normalized_version; + auto merged_variant_name = engine + "-" + latest_version_semantic + "-" + variant_name.value() + ".tar.gz"; for (const auto& asset : res.value()) { if (asset.name == merged_variant_name) { @@ -271,9 +264,10 @@ cpp::result EngineService::DownloadEngineV2( } } - if (selected_variant == std::nullopt) { + if (!selected_variant) { return cpp::fail("Failed to find a suitable variant for " + engine); } + if (IsEngineLoaded(engine)) { CTL_INF("Engine " << engine << " is already loaded, unloading it"); auto unload_res = UnloadEngine(engine); @@ -284,49 +278,34 @@ cpp::result EngineService::DownloadEngineV2( CTL_INF("Engine " << engine << " unloaded successfully"); } } - auto normalize_version = "v" + selected_variant->version; - - auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion( - selected_variant->name, engine, selected_variant->version); - - auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / - engine / variant_folder_name.value() / - normalize_version; + auto normalize_version = "v" + selected_variant->version; + auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion(selected_variant->name, engine, selected_variant->version); + auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / engine / variant_folder_name.value() / normalize_version; auto variant_path = variant_folder_path / selected_variant->name; + std::filesystem::create_directories(variant_folder_path); CLI_LOG("variant_folder_path: " + variant_folder_path.string()); - auto on_finished = [this, engine, selected_variant, variant_folder_path, - normalize_version](const DownloadTask& finishedTask) { - // try to unzip the downloaded file + + auto on_finished = [this, engine, selected_variant, variant_folder_path, normalize_version](const DownloadTask& finishedTask) { CLI_LOG("Engine zip path: " << finishedTask.items[0].localPath.string()); CLI_LOG("Version: " + normalize_version); auto extract_path = finishedTask.items[0].localPath.parent_path(); + archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), extract_path.string(), true); - archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), - extract_path.string(), true); - - auto variant = engine_matcher_utils::GetVariantFromNameAndVersion( - selected_variant->name, engine, normalize_version); + auto variant = engine_matcher_utils::GetVariantFromNameAndVersion(selected_variant->name, engine, normalize_version); + CLI_LOG("Extracted variant: " + variant.value()); - // set as default - auto res = - SetDefaultEngineVariant(engine, normalize_version, variant.value()); + auto res = SetDefaultEngineVariant(engine, normalize_version, variant.value()); if (res.has_error()) { CTL_ERR("Failed to set default engine variant: " << res.error()); } else { CTL_INF("Set default engine variant: " << res.value().variant); } - // remove other engines - auto engine_directories = file_manager_utils::GetEnginesContainerPath() / - engine / selected_variant->name; - - for (const auto& entry : std::filesystem::directory_iterator( - variant_folder_path.parent_path())) { - if (entry.is_directory() && - entry.path().filename() != normalize_version) { + for (const auto& entry : std::filesystem::directory_iterator(variant_folder_path.parent_path())) { + if (entry.is_directory() && entry.path().filename() != normalize_version) { try { std::filesystem::remove_all(entry.path()); } catch (const std::exception& e) { @@ -335,7 +314,6 @@ cpp::result EngineService::DownloadEngineV2( } } - // remove the downloaded file try { std::filesystem::remove(finishedTask.items[0].localPath); } catch (const std::exception& e) { @@ -344,19 +322,20 @@ cpp::result EngineService::DownloadEngineV2( CTL_INF("Finished!"); }; - auto downloadTask{ - DownloadTask{.id = engine, - .type = DownloadType::Engine, - .items = {DownloadItem{ - .id = engine, - .downloadUrl = selected_variant->browser_download_url, - .localPath = variant_path, - }}}}; + auto downloadTask = DownloadTask{ + .id = engine, + .type = DownloadType::Engine, + .items = {DownloadItem{ + .id = engine, + .downloadUrl = selected_variant->browser_download_url, + .localPath = variant_path, + }} + }; auto add_task_result = download_service_->AddTask(downloadTask, on_finished); - if (res.has_error()) { - return cpp::fail(res.error()); - } + if (add_task_result.has_error()) { + return cpp::fail(add_task_result.error()); + } return {}; } From 18f39003c1dd322de09b69cffdfff2d8a251e97f Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Wed, 13 Nov 2024 17:50:05 +0700 Subject: [PATCH 12/33] feat: db ops for engines --- engine/database/engines.cc | 98 +++++++++++++++++++++++++++---- engine/database/engines.h | 17 +++++- engine/services/engine_service.cc | 20 +++++++ engine/utils/result.hpp | 1 - 4 files changed, 121 insertions(+), 15 deletions(-) diff --git a/engine/database/engines.cc b/engine/database/engines.cc index ec8049c0a..f2a5c4d09 100644 --- a/engine/database/engines.cc +++ b/engine/database/engines.cc @@ -1,19 +1,93 @@ #include "engines.h" +#include +#include #include "database.h" - namespace cortex::db { +void CreateTable(SQLite::Database& db) { + db.exec( + "CREATE TABLE IF NOT EXISTS engines (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "engine_name TEXT," + "type TEXT," + "api_key TEXT," + "url TEXT," + "version TEXT," + "variant TEXT," + "status TEXT," + "metadata TEXT);"); +} + Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { - db_.exec( - "CREATE TABLE IF NOT EXISTS engines (" - "engine_id TEXT PRIMARY KEY," - "type TEXT," - "api_key TEXT," - "url TEXT," - "version TEXT," - "variant TEXT," - "status TEXT," - "metadata TEXT);"); + CreateTable(db_); } + +Engines::Engines(SQLite::Database& db) : db_(db) { + CreateTable(db_); +} + Engines::~Engines() {} -} \ No newline at end of file + + + +// Function to create a new engine and save it into the database +std::optional Engines::UpsertEngine(const std::string& engine_name, + const std::string& type, + const std::string& api_key, + const std::string& url, + const std::string& version, + const std::string& variant, + const std::string& status, + const std::string& metadata) { + try { + SQLite::Statement query(db_, + "INSERT INTO engines (engine_name, type, api_key, url, version, variant, status, metadata) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"); + + query.bind(1, engine_name); + query.bind(2, type); + query.bind(3, api_key); + query.bind(4, url); + query.bind(5, version); + query.bind(6, variant); + query.bind(7, status); + query.bind(8, metadata); + + query.exec(); + return std::nullopt; + } catch (const std::exception& e) { + return std::string("Failed to create engine: ") + e.what(); + } +} + +std::optional Engines::GetEngine(int id, const std::string& engine_name) const { + try { + SQLite::Statement query(db_, + "SELECT engine_name FROM engines WHERE (id = ? OR engine_name = ?) AND status = 'Default' LIMIT 1"); + + query.bind(1, id); + query.bind(2, engine_name); + + if (query.executeStep()) { + return EngineEntry{query.getColumn(0).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::DeleteEngine(int id) { + try { + SQLite::Statement query(db_, + "DELETE FROM engines WHERE id = ?"); + + query.bind(1, id); + query.exec(); + return std::nullopt; + } catch (const std::exception& e) { + return std::string("Failed to delete engine: ") + e.what(); + } +} +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/engines.h b/engine/database/engines.h index 737ba0fb0..927b2679d 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -4,7 +4,7 @@ #include #include #include -#include "utils/result.hpp" +#include namespace cortex::db { @@ -21,12 +21,25 @@ class Engines { const std::string& model_id, const std::string& model_alias) const; - cpp::result, std::string> LoadModelListNoLock() const; + std::optional> LoadModelListNoLock() const; public: Engines(); Engines(SQLite::Database& db); ~Engines(); + + std::optional UpsertEngine(const std::string& engine_name, + const std::string& type, + const std::string& api_key, + const std::string& url, + const std::string& version, + const std::string& variant, + const std::string& status, + const std::string& metadata); + + std::optional GetEngine(int id, const std::string& engine_name) const; + + std::optional DeleteEngine(int id); }; } // namespace cortex::db \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 08cd8b122..a4ea44613 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -13,6 +13,7 @@ #include "utils/semantic_version_utils.h" #include "utils/system_info_utils.h" #include "utils/url_parser.h" +#include "database/engines.h" namespace { std::string GetSuitableCudaVersion(const std::string& engine, @@ -304,6 +305,25 @@ cpp::result EngineService::DownloadEngineV2( CTL_INF("Set default engine variant: " << res.value().variant); } + // Create engine entry in the database + cortex::db::Engines engines; + auto create_res = engines.UpsertEngine( + engine, // engine_name + "", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, + variant.value(), + "Default", // todo - luke + "" // todo - luke + ); + + if (create_res.has_value()) { + CTL_ERR("Failed to create engine entry: " << create_res.value()); + } else { + CTL_INF("Engine entry created successfully"); + } + for (const auto& entry : std::filesystem::directory_iterator(variant_folder_path.parent_path())) { if (entry.is_directory() && entry.path().filename() != normalize_version) { try { diff --git a/engine/utils/result.hpp b/engine/utils/result.hpp index 96243f72e..7f7356b84 100644 --- a/engine/utils/result.hpp +++ b/engine/utils/result.hpp @@ -34,7 +34,6 @@ #include // std::size_t #include // std::enable_if, std::is_constructible, etc -#include // placement-new #include // std::address_of #include // std::reference_wrapper, std::invoke #include // std::in_place_t, std::forward From c5148e57291a88e9c94e24e5557955e30d5c9973 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Wed, 13 Nov 2024 19:01:12 +0700 Subject: [PATCH 13/33] chore: small refactor before more changes --- engine/services/engine_service.h | 44 ++++++++++++-------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index a9f8c899c..916770a06 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -4,6 +4,9 @@ #include #include #include +#include +#include + #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" #include "services/download_service.h" @@ -13,11 +16,12 @@ #include "utils/github_release_utils.h" #include "utils/result.hpp" #include "utils/system_info_utils.h" - -#include "extensions/remote-engine/remote_engine.h" - #include "common/engine_servicei.h" +namespace system_info_utils { +struct SystemInfo; +} + struct EngineUpdateResult { std::string engine; std::string variant; @@ -34,10 +38,6 @@ struct EngineUpdateResult { } }; -namespace system_info_utils { -struct SystemInfo; -} - using EngineV = std::variant; class EngineService : public EngineServiceI { @@ -55,6 +55,14 @@ class EngineService : public EngineServiceI { }; std::unordered_map engines_{}; + std::shared_ptr download_service_; + + struct HardwareInfo { + std::unique_ptr sys_inf; + cortex::cpuid::CpuInfo cpu_inf; + std::string cuda_driver_version; + }; + HardwareInfo hw_inf_; public: const std::vector kSupportEngines = { @@ -67,21 +75,11 @@ class EngineService : public EngineServiceI { std::vector GetEngineInfoList() const; - /** - * Check if an engines is ready (have at least one variant installed) - */ cpp::result IsEngineReady(const std::string& engine) const; - cpp::result InstallEngineAsync( const std::string& engine, const std::string& version = "latest", const std::string& src = ""); - /** - * Handling install engine variant. - * - * If no version provided, choose `latest`. - * If no variant provided, automatically pick the best variant. - */ cpp::result InstallEngineAsyncV2( const std::string& engine, const std::string& version, const std::optional variant_name); @@ -114,7 +112,6 @@ class EngineService : public EngineServiceI { std::vector GetLoadedEngines(); cpp::result LoadEngine(const std::string& engine_name); - cpp::result UnloadEngine(const std::string& engine_name); cpp::result @@ -145,13 +142,4 @@ class EngineService : public EngineServiceI { cpp::result IsEngineVariantReady( const std::string& engine, const std::string& version, const std::string& variant); - - std::shared_ptr download_service_; - - struct HardwareInfo { - std::unique_ptr sys_inf; - cortex::cpuid::CpuInfo cpu_inf; - std::string cuda_driver_version; - }; - HardwareInfo hw_inf_; -}; +}; \ No newline at end of file From b2567addeef5573d8e334261194c90cdc17ba488 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Thu, 14 Nov 2024 00:51:50 +0700 Subject: [PATCH 14/33] Update engine --- engine/CMakeLists.txt | 5 +- engine/cli/CMakeLists.txt | 5 +- .../extensions/remote-engine/remote_engine.cc | 157 +++++++++++--- .../extensions/remote-engine/remote_engine.h | 61 +++++- engine/services/engine_service.cc | 107 +++++++--- engine/utils/remote_models_utils.h | 198 ++++++++++-------- 6 files changed, 378 insertions(+), 155 deletions(-) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 92e07ec91..a6774c64e 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -139,6 +139,8 @@ file(APPEND "${CMAKE_CURRENT_BINARY_DIR}/cortex_openapi.h" add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/TemplateRenderer.cc ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -173,10 +175,11 @@ aux_source_directory(models MODEL_SRC) aux_source_directory(cortex-common CORTEX_COMMON) aux_source_directory(config CONFIG_SRC) aux_source_directory(database DB_SRC) +aux_source_directory(extensions EX_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${EX_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 758a51dc8..7c1a4e103 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -78,6 +78,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/TemplateRenderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ) @@ -114,11 +116,12 @@ aux_source_directory(../cortex-common CORTEX_COMMON) aux_source_directory(../config CONFIG_SRC) aux_source_directory(commands COMMANDS_SRC) aux_source_directory(../database DB_SRC) +aux_source_directory(../extensions EX_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/.. ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC} ${EX_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index afefe1f0c..c59c3e865 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -10,6 +10,64 @@ constexpr const int k409Conflict = 409; constexpr const int k500InternalServerError = 500; constexpr const int kFileLoggerOption = 0; +CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( + const ModelConfig& config, const std::string& body, + std::function callback) { + + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = + config.transform_req["chat_completions"]["url"].as(); + + struct curl_slist* headers = nullptr; + if (!config.api_key.empty()) { + headers = curl_slist_append(headers, api_key_template_.c_str()); + } + + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, "Accept: text/event-stream"); + headers = curl_slist_append(headers, "Cache-Control: no-cache"); + headers = curl_slist_append(headers, "Connection: keep-alive"); + + StreamContext context{callback, ""}; + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context); + curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = 500; + + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, const std::string& apiKey) { const std::string placeholder = "{{api_key}}"; @@ -226,7 +284,12 @@ void RemoteEngine::LoadModel( !json_body->isMember("api_key")) { Json::Value error; error["error"] = "Missing required fields: model or model_path"; - callback(Json::Value(), std::move(error)); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); return; } @@ -237,7 +300,12 @@ void RemoteEngine::LoadModel( if (!LoadModelConfig(model, model_path, api_key)) { Json::Value error; error["error"] = "Failed to load model configuration"; - callback(Json::Value(), std::move(error)); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); return; } if (json_body->isMember("metadata")) { @@ -246,7 +314,12 @@ void RemoteEngine::LoadModel( Json::Value response; response["status"] = "Model loaded successfully"; - callback(Json::Value(), std::move(response)); + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); } void RemoteEngine::UnloadModel( @@ -255,7 +328,12 @@ void RemoteEngine::UnloadModel( if (!json_body->isMember("model")) { Json::Value error; error["error"] = "Missing required field: model"; - callback(Json::Value(), std::move(error)); + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); return; } @@ -268,7 +346,12 @@ void RemoteEngine::UnloadModel( Json::Value response; response["status"] = "Model unloaded successfully"; - callback(std::move(response), Json::Value()); + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); } void RemoteEngine::HandleChatCompletion( @@ -300,7 +383,8 @@ void RemoteEngine::HandleChatCompletion( callback(std::move(status), std::move(error)); return; } - + bool is_stream = + json_body->isMember("stream") && (*json_body)["stream"].asBool(); Json::FastWriter writer; std::string request_body = writer.write((*json_body)); std::cout << "template: " @@ -311,41 +395,45 @@ void RemoteEngine::HandleChatCompletion( model_config->transform_req["chat_completions"]["template"] .as(), (*json_body)); + if (is_stream) { + MakeStreamingChatCompletionRequest(*model_config, result, callback); + } else { - auto response = MakeChatCompletionRequest(*model_config, result); + auto response = MakeChatCompletionRequest(*model_config, result); + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } - if (response.error) { + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } Json::Value status; status["is_done"] = true; - status["has_error"] = true; + status["has_error"] = false; status["is_stream"] = false; - status["status_code"] = k400BadRequest; - Json::Value error; - error["error"] = response.error_message; - callback(std::move(status), std::move(error)); - return; - } + status["status_code"] = k200OK; - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value status; - status["is_done"] = true; - status["has_error"] = true; - status["is_stream"] = false; - status["status_code"] = k500InternalServerError; - Json::Value error; - error["error"] = "Failed to parse response"; - callback(std::move(status), std::move(error)); - return; + callback(std::move(status), std::move(response_json)); } - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = k200OK; - - callback(std::move(status), std::move(response_json)); } void RemoteEngine::GetModelStatus( @@ -417,6 +505,7 @@ bool RemoteEngine::SetFileLogger(int max_log_lines, }); freopen(log_path.c_str(), "w", stderr); freopen(log_path.c_str(), "w", stdout); + return true; } void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 2ae613b8d..3b6226d95 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -13,6 +13,61 @@ // Helper for CURL response namespace remote_engine { + +struct StreamContext { + std::function callback; + std::string buffer; +}; + +static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + + // Skip empty lines + if (line.empty() || line == "\r") + continue; + + // Remove "data: " prefix if present + if (line.substr(0, 6) == "data: ") { + line = line.substr(6); + } + + // Skip [DONE] message + if (line == "[DONE]") { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + context->callback(std::move(status), Json::Value()); + continue; + } + + // Parse the JSON + Json::Value chunk_json; + Json::Reader reader; + if (reader.parse(line, chunk_json)) { + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + context->callback(std::move(status), std::move(chunk_json)); + } + } + + return size * nmemb; +} + struct CurlResponse { std::string body; bool error{false}; @@ -43,11 +98,13 @@ class RemoteEngine : public EngineI { CurlResponse MakeChatCompletionRequest(const ModelConfig& config, const std::string& body, const std::string& method = "POST"); + CurlResponse MakeStreamingChatCompletionRequest( + const ModelConfig& config, const std::string& body, + std::function callback); CurlResponse MakeGetModelsRequest(); // Internal model management - bool LoadModelConfig(const std::string& model, - const std::string& yaml_path, + bool LoadModelConfig(const std::string& model, const std::string& yaml_path, const std::string& api_key); ModelConfig* GetModelConfig(const std::string& model); diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index a4ea44613..49699da31 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -3,6 +3,7 @@ #include #include #include "algorithm" +#include "database/engines.h" #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" @@ -13,8 +14,6 @@ #include "utils/semantic_version_utils.h" #include "utils/system_info_utils.h" #include "utils/url_parser.h" -#include "database/engines.h" - namespace { std::string GetSuitableCudaVersion(const std::string& engine, const std::string& cuda_driver_version) { @@ -230,7 +229,9 @@ cpp::result EngineService::DownloadEngineV2( const std::string& engine, const std::string& version, const std::optional variant_name) { - auto normalized_version = version == "latest" ? "latest" : string_utils::RemoveSubstring(version, "v"); + auto normalized_version = version == "latest" + ? "latest" + : string_utils::RemoveSubstring(version, "v"); auto res = GetEngineVariants(engine, version); if (res.has_error()) { return cpp::fail("Failed to fetch engine releases: " + res.error()); @@ -241,8 +242,11 @@ cpp::result EngineService::DownloadEngineV2( std::optional selected_variant = std::nullopt; if (variant_name.has_value()) { - auto latest_version_semantic = normalized_version == "latest" ? res.value()[0].version : normalized_version; - auto merged_variant_name = engine + "-" + latest_version_semantic + "-" + variant_name.value() + ".tar.gz"; + auto latest_version_semantic = normalized_version == "latest" + ? res.value()[0].version + : normalized_version; + auto merged_variant_name = engine + "-" + latest_version_semantic + "-" + + variant_name.value() + ".tar.gz"; for (const auto& asset : res.value()) { if (asset.name == merged_variant_name) { @@ -281,24 +285,31 @@ cpp::result EngineService::DownloadEngineV2( } auto normalize_version = "v" + selected_variant->version; - auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion(selected_variant->name, engine, selected_variant->version); - auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / engine / variant_folder_name.value() / normalize_version; + auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion( + selected_variant->name, engine, selected_variant->version); + auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / + engine / variant_folder_name.value() / + normalize_version; auto variant_path = variant_folder_path / selected_variant->name; std::filesystem::create_directories(variant_folder_path); CLI_LOG("variant_folder_path: " + variant_folder_path.string()); - auto on_finished = [this, engine, selected_variant, variant_folder_path, normalize_version](const DownloadTask& finishedTask) { + auto on_finished = [this, engine, selected_variant, variant_folder_path, + normalize_version](const DownloadTask& finishedTask) { CLI_LOG("Engine zip path: " << finishedTask.items[0].localPath.string()); CLI_LOG("Version: " + normalize_version); auto extract_path = finishedTask.items[0].localPath.parent_path(); - archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), extract_path.string(), true); + archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), + extract_path.string(), true); + + auto variant = engine_matcher_utils::GetVariantFromNameAndVersion( + selected_variant->name, engine, normalize_version); - auto variant = engine_matcher_utils::GetVariantFromNameAndVersion(selected_variant->name, engine, normalize_version); - CLI_LOG("Extracted variant: " + variant.value()); - auto res = SetDefaultEngineVariant(engine, normalize_version, variant.value()); + auto res = + SetDefaultEngineVariant(engine, normalize_version, variant.value()); if (res.has_error()) { CTL_ERR("Failed to set default engine variant: " << res.error()); } else { @@ -307,15 +318,13 @@ cpp::result EngineService::DownloadEngineV2( // Create engine entry in the database cortex::db::Engines engines; - auto create_res = engines.UpsertEngine( - engine, // engine_name - "", // todo - luke - "", // todo - luke - "", // todo - luke - normalize_version, - variant.value(), - "Default", // todo - luke - "" // todo - luke + auto create_res = engines.UpsertEngine(engine, // engine_name + "", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, variant.value(), + "Default", // todo - luke + "" // todo - luke ); if (create_res.has_value()) { @@ -324,8 +333,10 @@ cpp::result EngineService::DownloadEngineV2( CTL_INF("Engine entry created successfully"); } - for (const auto& entry : std::filesystem::directory_iterator(variant_folder_path.parent_path())) { - if (entry.is_directory() && entry.path().filename() != normalize_version) { + for (const auto& entry : std::filesystem::directory_iterator( + variant_folder_path.parent_path())) { + if (entry.is_directory() && + entry.path().filename() != normalize_version) { try { std::filesystem::remove_all(entry.path()); } catch (const std::exception& e) { @@ -342,20 +353,19 @@ cpp::result EngineService::DownloadEngineV2( CTL_INF("Finished!"); }; - auto downloadTask = DownloadTask{ - .id = engine, - .type = DownloadType::Engine, - .items = {DownloadItem{ - .id = engine, - .downloadUrl = selected_variant->browser_download_url, - .localPath = variant_path, - }} - }; + auto downloadTask = + DownloadTask{.id = engine, + .type = DownloadType::Engine, + .items = {DownloadItem{ + .id = engine, + .downloadUrl = selected_variant->browser_download_url, + .localPath = variant_path, + }}}; auto add_task_result = download_service_->AddTask(downloadTask, on_finished); if (add_task_result.has_error()) { return cpp::fail(add_task_result.error()); - } + } return {}; } @@ -750,6 +760,30 @@ cpp::result EngineService::LoadEngine( return {}; } + // TODO (alex): Need to change this, now Hard code for testing remote engine + if (engine_name != kLlamaEngine && engine_name != kOnnxEngine && + engine_name != kTrtLlmEngine) { + engines_[engine_name].engine = new remote_engine::RemoteEngine(); + auto& en = std::get(engines_[ne].engine); + auto config = file_manager_utils::GetCortexConfig(); + if (en->IsSupported("SetFileLogger")) { + en->SetFileLogger(config.maxLogLines, + (std::filesystem::path(config.logFolderPath) / + std::filesystem::path(config.logLlamaCppPath)) + .string()); + } else { + CTL_WRN("Method SetFileLogger is not supported yet"); + } + if (en->IsSupported("SetLogLevel")) { + en->SetLogLevel(trantor::Logger::logLevel()); + } else { + CTL_WRN("Method SetLogLevel is not supported yet"); + } + return {}; + } + + // End hard code + CTL_INF("Loading engine: " << ne); auto selected_engine_variant = GetDefaultEngineVariant(ne); @@ -934,6 +968,13 @@ cpp::result EngineService::IsEngineReady( const std::string& engine) const { auto ne = NormalizeEngine(engine); + // Hard code to test remote engine + if (engine != kLlamaRepo && engine != kTrtLlmRepo && engine != kOnnxRepo) { + return true; + } + + // End hard code + auto os = hw_inf_.sys_inf->os; if (os == kMacOs && (ne == kOnnxRepo || ne == kTrtLlmRepo)) { return cpp::fail("Engine " + engine + " is not supported on macOS"); diff --git a/engine/utils/remote_models_utils.h b/engine/utils/remote_models_utils.h index 6c13f62c1..7b7906f2c 100644 --- a/engine/utils/remote_models_utils.h +++ b/engine/utils/remote_models_utils.h @@ -1,102 +1,132 @@ #pragma once -#include #include +#include #include namespace remote_models_utils { +constexpr char chat_completion_request_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_response_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_url[] = + "https://api.openai.com/v1/chat/completions"; inline Json::Value yamlToJson(const YAML::Node& node) { - Json::Value result; + Json::Value result; - switch (node.Type()) { - case YAML::NodeType::Null: - return Json::Value(); - case YAML::NodeType::Scalar: { - // For scalar types, we'll first try to parse as string - std::string str_val = node.as(); - - // Try to parse as boolean - if (str_val == "true" || str_val == "True" || str_val == "TRUE") - return Json::Value(true); - if (str_val == "false" || str_val == "False" || str_val == "FALSE") - return Json::Value(false); - - // Try to parse as number - try { - // Check if it's an integer - size_t pos; - long long int_val = std::stoll(str_val, &pos); - if (pos == str_val.length()) { - return Json::Value(static_cast(int_val)); - } + switch (node.Type()) { + case YAML::NodeType::Null: + return Json::Value(); + case YAML::NodeType::Scalar: { + // For scalar types, we'll first try to parse as string + std::string str_val = node.as(); - // Check if it's a float - double float_val = std::stod(str_val, &pos); - if (pos == str_val.length()) { - return Json::Value(float_val); - } - } catch (...) { - // If parsing as number fails, use as string - } - - // Default to string if no other type matches - return Json::Value(str_val); - } - case YAML::NodeType::Sequence: { - result = Json::Value(Json::arrayValue); - for (const auto& elem : node) { - result.append(yamlToJson(elem)); - } - return result; + // Try to parse as boolean + if (str_val == "true" || str_val == "True" || str_val == "TRUE") + return Json::Value(true); + if (str_val == "false" || str_val == "False" || str_val == "FALSE") + return Json::Value(false); + + // Try to parse as number + try { + // Check if it's an integer + size_t pos; + long long int_val = std::stoll(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(static_cast(int_val)); } - case YAML::NodeType::Map: { - result = Json::Value(Json::objectValue); - for (const auto& it : node) { - std::string key = it.first.as(); - result[key] = yamlToJson(it.second); - } - return result; + + // Check if it's a float + double float_val = std::stod(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(float_val); } - default: - return Json::Value(); + } catch (...) { + // If parsing as number fails, use as string + } + + // Default to string if no other type matches + return Json::Value(str_val); + } + case YAML::NodeType::Sequence: { + result = Json::Value(Json::arrayValue); + for (const auto& elem : node) { + result.append(yamlToJson(elem)); + } + return result; } + case YAML::NodeType::Map: { + result = Json::Value(Json::objectValue); + for (const auto& it : node) { + std::string key = it.first.as(); + result[key] = yamlToJson(it.second); + } + return result; + } + default: + return Json::Value(); + } } inline YAML::Node jsonToYaml(const Json::Value& json) { - YAML::Node result; - - switch (json.type()) { - case Json::nullValue: - result = YAML::Node(YAML::NodeType::Null); - break; - case Json::intValue: - result = json.asInt64(); - break; - case Json::uintValue: - result = json.asUInt64(); - break; - case Json::realValue: - result = json.asDouble(); - break; - case Json::stringValue: - result = json.asString(); - break; - case Json::booleanValue: - result = json.asBool(); - break; - case Json::arrayValue: - result = YAML::Node(YAML::NodeType::Sequence); - for (const auto& elem : json) - result.push_back(jsonToYaml(elem)); - break; - case Json::objectValue: - result = YAML::Node(YAML::NodeType::Map); - for (const auto& key : json.getMemberNames()) - result[key] = jsonToYaml(json[key]); - break; - } - return result; + YAML::Node result; + + switch (json.type()) { + case Json::nullValue: + result = YAML::Node(YAML::NodeType::Null); + break; + case Json::intValue: + result = json.asInt64(); + break; + case Json::uintValue: + result = json.asUInt64(); + break; + case Json::realValue: + result = json.asDouble(); + break; + case Json::stringValue: + result = json.asString(); + break; + case Json::booleanValue: + result = json.asBool(); + break; + case Json::arrayValue: + result = YAML::Node(YAML::NodeType::Sequence); + for (const auto& elem : json) + result.push_back(jsonToYaml(elem)); + break; + case Json::objectValue: + result = YAML::Node(YAML::NodeType::Map); + for (const auto& key : json.getMemberNames()) + result[key] = jsonToYaml(json[key]); + break; + } + return result; } -} // namespace utils \ No newline at end of file +} // namespace remote_models_utils \ No newline at end of file From ca3972ef2aa612c211603e688551c9540727b411 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Thu, 14 Nov 2024 07:26:29 +0700 Subject: [PATCH 15/33] refine db schema, composite key for engines --- engine/database/engines.cc | 33 +++++++++++++++++++++---------- engine/services/engine_service.cc | 1 + 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/engine/database/engines.cc b/engine/database/engines.cc index f2a5c4d09..1b38b2479 100644 --- a/engine/database/engines.cc +++ b/engine/database/engines.cc @@ -2,6 +2,7 @@ #include #include #include "database.h" + namespace cortex::db { void CreateTable(SQLite::Database& db) { @@ -15,8 +16,11 @@ void CreateTable(SQLite::Database& db) { "version TEXT," "variant TEXT," "status TEXT," - "metadata TEXT);"); -} + "metadata TEXT," + "date_created TEXT DEFAULT CURRENT_TIMESTAMP," + "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," + "UNIQUE(engine_name, variant));"); // Add UNIQUE constraint +} Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { CreateTable(db_); @@ -28,8 +32,6 @@ Engines::Engines(SQLite::Database& db) : db_(db) { Engines::~Engines() {} - - // Function to create a new engine and save it into the database std::optional Engines::UpsertEngine(const std::string& engine_name, const std::string& type, @@ -40,9 +42,17 @@ std::optional Engines::UpsertEngine(const std::string& engine_name, const std::string& status, const std::string& metadata) { try { - SQLite::Statement query(db_, + SQLite::Statement query(db_, "INSERT INTO engines (engine_name, type, api_key, url, version, variant, status, metadata) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"); + "VALUES (?, ?, ?, ?, ?, ?, ?, ?) " + "ON CONFLICT(engine_name, variant) DO UPDATE SET " + "type = excluded.type, " + "api_key = excluded.api_key, " + "url = excluded.url, " + "version = excluded.version, " + "status = excluded.status, " + "metadata = excluded.metadata, " + "date_updated = CURRENT_TIMESTAMP;"); query.bind(1, engine_name); query.bind(2, type); @@ -56,14 +66,16 @@ std::optional Engines::UpsertEngine(const std::string& engine_name, query.exec(); return std::nullopt; } catch (const std::exception& e) { - return std::string("Failed to create engine: ") + e.what(); + return std::string("Failed to upsert engine: ") + e.what(); } } std::optional Engines::GetEngine(int id, const std::string& engine_name) const { try { - SQLite::Statement query(db_, - "SELECT engine_name FROM engines WHERE (id = ? OR engine_name = ?) AND status = 'Default' LIMIT 1"); + SQLite::Statement query(db_, + "SELECT engine_name FROM engines " + "WHERE (id = ? OR engine_name = ?) AND status = 'Default' " + "ORDER BY date_updated DESC LIMIT 1"); query.bind(1, id); query.bind(2, engine_name); @@ -80,7 +92,7 @@ std::optional Engines::GetEngine(int id, const std::string& engine_ std::optional Engines::DeleteEngine(int id) { try { - SQLite::Statement query(db_, + SQLite::Statement query(db_, "DELETE FROM engines WHERE id = ?"); query.bind(1, id); @@ -90,4 +102,5 @@ std::optional Engines::DeleteEngine(int id) { return std::string("Failed to delete engine: ") + e.what(); } } + } // namespace cortex::db \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 49699da31..74142bc54 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -4,6 +4,7 @@ #include #include "algorithm" #include "database/engines.h" +#include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" From bedb8030dd77ec7ad8fe8b14a45ab16a1e2cf1f9 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Thu, 14 Nov 2024 07:46:31 +0700 Subject: [PATCH 16/33] add entry definition for engine at db layer --- engine/database/engines.cc | 44 +++++++++++++++++++++++++------ engine/database/engines.h | 15 ++++++++--- engine/services/engine_service.cc | 2 +- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/engine/database/engines.cc b/engine/database/engines.cc index 1b38b2479..1f38ae3b7 100644 --- a/engine/database/engines.cc +++ b/engine/database/engines.cc @@ -32,8 +32,7 @@ Engines::Engines(SQLite::Database& db) : db_(db) { Engines::~Engines() {} -// Function to create a new engine and save it into the database -std::optional Engines::UpsertEngine(const std::string& engine_name, +std::optional Engines::UpsertEngine(const std::string& engine_name, const std::string& type, const std::string& api_key, const std::string& url, @@ -52,7 +51,8 @@ std::optional Engines::UpsertEngine(const std::string& engine_name, "version = excluded.version, " "status = excluded.status, " "metadata = excluded.metadata, " - "date_updated = CURRENT_TIMESTAMP;"); + "date_updated = CURRENT_TIMESTAMP " + "RETURNING id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated;"); query.bind(1, engine_name); query.bind(2, type); @@ -63,17 +63,33 @@ std::optional Engines::UpsertEngine(const std::string& engine_name, query.bind(7, status); query.bind(8, metadata); - query.exec(); - return std::nullopt; + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), + query.getColumn(1).getString(), + query.getColumn(2).getString(), + query.getColumn(3).getString(), + query.getColumn(4).getString(), + query.getColumn(5).getString(), + query.getColumn(6).getString(), + query.getColumn(7).getString(), + query.getColumn(8).getString(), + query.getColumn(9).getString(), + query.getColumn(10).getString() + }; + } else { + return std::nullopt; + } } catch (const std::exception& e) { - return std::string("Failed to upsert engine: ") + e.what(); + return std::nullopt; } } std::optional Engines::GetEngine(int id, const std::string& engine_name) const { try { SQLite::Statement query(db_, - "SELECT engine_name FROM engines " + "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " + "FROM engines " "WHERE (id = ? OR engine_name = ?) AND status = 'Default' " "ORDER BY date_updated DESC LIMIT 1"); @@ -81,7 +97,19 @@ std::optional Engines::GetEngine(int id, const std::string& engine_ query.bind(2, engine_name); if (query.executeStep()) { - return EngineEntry{query.getColumn(0).getString()}; + return EngineEntry{ + query.getColumn(0).getInt(), + query.getColumn(1).getString(), + query.getColumn(2).getString(), + query.getColumn(3).getString(), + query.getColumn(4).getString(), + query.getColumn(5).getString(), + query.getColumn(6).getString(), + query.getColumn(7).getString(), + query.getColumn(8).getString(), + query.getColumn(9).getString(), + query.getColumn(10).getString() + }; } else { return std::nullopt; } diff --git a/engine/database/engines.h b/engine/database/engines.h index 927b2679d..93d126eee 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -9,11 +9,20 @@ namespace cortex::db { struct EngineEntry { - std::string engine; + int id; + std::string engine_name; + std::string type; + std::string api_key; + std::string url; + std::string version; + std::string variant; + std::string status; + std::string metadata; + std::string date_created; + std::string date_updated; }; class Engines { - private: SQLite::Database& db_; @@ -28,7 +37,7 @@ class Engines { Engines(SQLite::Database& db); ~Engines(); - std::optional UpsertEngine(const std::string& engine_name, + std::optional UpsertEngine(const std::string& engine_name, const std::string& type, const std::string& api_key, const std::string& url, diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 74142bc54..8e46078f7 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -329,7 +329,7 @@ cpp::result EngineService::DownloadEngineV2( ); if (create_res.has_value()) { - CTL_ERR("Failed to create engine entry: " << create_res.value()); + CTL_ERR("Failed to create engine entry: " << create_res->engine_name); } else { CTL_INF("Engine entry created successfully"); } From a10294e27d457c28f52b3aee7e1e4600e45bdb00 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Thu, 14 Nov 2024 08:06:53 +0700 Subject: [PATCH 17/33] complete add, get engine operations --- engine/database/engines.cc | 84 ++++++++++++++++++++++++++++--- engine/database/engines.h | 6 ++- engine/services/engine_service.cc | 51 +++++++++++++++++++ engine/services/engine_service.h | 6 +++ 4 files changed, 139 insertions(+), 8 deletions(-) diff --git a/engine/database/engines.cc b/engine/database/engines.cc index 1f38ae3b7..5fa71bbf2 100644 --- a/engine/database/engines.cc +++ b/engine/database/engines.cc @@ -1,5 +1,4 @@ #include "engines.h" -#include #include #include "database.h" @@ -19,7 +18,7 @@ void CreateTable(SQLite::Database& db) { "metadata TEXT," "date_created TEXT DEFAULT CURRENT_TIMESTAMP," "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," - "UNIQUE(engine_name, variant));"); // Add UNIQUE constraint + "UNIQUE(engine_name, variant));"); } Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { @@ -85,16 +84,46 @@ std::optional Engines::UpsertEngine(const std::string& engine_name, } } -std::optional Engines::GetEngine(int id, const std::string& engine_name) const { +std::optional> Engines::GetEngines() const { try { SQLite::Statement query(db_, "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " "FROM engines " - "WHERE (id = ? OR engine_name = ?) AND status = 'Default' " + "WHERE status = 'Default' " + "ORDER BY date_updated DESC"); + + std::vector engines; + while (query.executeStep()) { + engines.push_back(EngineEntry{ + query.getColumn(0).getInt(), + query.getColumn(1).getString(), + query.getColumn(2).getString(), + query.getColumn(3).getString(), + query.getColumn(4).getString(), + query.getColumn(5).getString(), + query.getColumn(6).getString(), + query.getColumn(7).getString(), + query.getColumn(8).getString(), + query.getColumn(9).getString(), + query.getColumn(10).getString() + }); + } + + return engines; + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::GetEngineById(int id) const { + try { + SQLite::Statement query(db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " + "FROM engines " + "WHERE id = ? AND status = 'Default' " "ORDER BY date_updated DESC LIMIT 1"); query.bind(1, id); - query.bind(2, engine_name); if (query.executeStep()) { return EngineEntry{ @@ -118,7 +147,50 @@ std::optional Engines::GetEngine(int id, const std::string& engine_ } } -std::optional Engines::DeleteEngine(int id) { +std::optional Engines::GetEngineByNameAndVariant(const std::string& engine_name, const std::optional variant) const { + try { + std::string queryStr = + "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " + "FROM engines " + "WHERE engine_name = ? AND status = 'Default' "; + + if (variant) { + queryStr += "AND variant = ? "; + } + + queryStr += "ORDER BY date_updated DESC LIMIT 1"; + + SQLite::Statement query(db_, queryStr); + + query.bind(1, engine_name); + + if (variant) { + query.bind(2, variant.value()); + } + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), + query.getColumn(1).getString(), + query.getColumn(2).getString(), + query.getColumn(3).getString(), + query.getColumn(4).getString(), + query.getColumn(5).getString(), + query.getColumn(6).getString(), + query.getColumn(7).getString(), + query.getColumn(8).getString(), + query.getColumn(9).getString(), + query.getColumn(10).getString() + }; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::DeleteEngineById(int id) { try { SQLite::Statement query(db_, "DELETE FROM engines WHERE id = ?"); diff --git a/engine/database/engines.h b/engine/database/engines.h index 93d126eee..4f5c48c86 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -46,9 +46,11 @@ class Engines { const std::string& status, const std::string& metadata); - std::optional GetEngine(int id, const std::string& engine_name) const; + std::optional> GetEngines() const; + std::optional GetEngineById(int id) const; + std::optional GetEngineByNameAndVariant(const std::string& engine_name, const std::optional variant = std::nullopt) const; - std::optional DeleteEngine(int id); + std::optional DeleteEngineById(int id); }; } // namespace cortex::db \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 8e46078f7..d81ddfb23 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include "algorithm" #include "database/engines.h" #include "extensions/remote-engine/remote_engine.h" @@ -1061,3 +1062,53 @@ cpp::result EngineService::UpdateEngine( .from = default_variant->version, .to = latest_version->tag_name}; } + + +cpp::result, std::string> GetEngineEntries() { + cortex::db::Engines engines; + auto get_res = engines.GetEngines(); + + if (!get_res.has_value()) { + return cpp::fail("Failed to get engine entries"); + } + + return get_res.value(); +} + +cpp::result GetEngineEntryById(int id) { + cortex::db::Engines engines; + auto get_res = engines.GetEngineById(id); + + if (!get_res.has_value()) { + return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); + } + + return get_res.value(); +} + +cpp::result GetEngineEntryByNameAndVariant( + const std::string& engine_name, const std::optional variant = std::nullopt) { + + cortex::db::Engines engines; + auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); + + if (!get_res.has_value()) { + if (variant.has_value()) { + return cpp::fail("Variant " + variant.value() + " not found for engine " + engine_name); + } else { + return cpp::fail("Engine " + engine_name + " not found"); + } + } + + return get_res.value(); +} + +std::string DeleteEngineEntryById(int id) { + cortex::db::Engines engines; + auto delete_res = engines.DeleteEngineById(id); + if (delete_res.has_value()) { + return delete_res.value(); + } else { + return "Failed to delete engine entry"; + } +} \ No newline at end of file diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 916770a06..32a4d38f8 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -9,6 +9,7 @@ #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" +#include "database/engines.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -124,6 +125,11 @@ class EngineService : public EngineServiceI { cpp::result UpdateEngine( const std::string& engine); + cpp::result GetEngineEntryById(int id); + + cpp::result GetEngineEntryByNameAndVariant( + const std::string& engine_name, const std::optional variant); + private: cpp::result DownloadEngine( const std::string& engine, const std::string& version = "latest", From 5f9e706af66091e77202d75e6ba106eb21982881 Mon Sep 17 00:00:00 2001 From: Luke Nguyen Date: Thu, 14 Nov 2024 08:28:57 +0700 Subject: [PATCH 18/33] engine managements --- engine/services/engine_service.cc | 45 +++++++++++++++++++++---------- engine/services/engine_service.h | 18 +++++++++++-- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index d81ddfb23..964f4a86c 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -317,16 +317,15 @@ cpp::result EngineService::DownloadEngineV2( } else { CTL_INF("Set default engine variant: " << res.value().variant); } - - // Create engine entry in the database - cortex::db::Engines engines; - auto create_res = engines.UpsertEngine(engine, // engine_name - "", // todo - luke - "", // todo - luke - "", // todo - luke - normalize_version, variant.value(), - "Default", // todo - luke - "" // todo - luke + auto create_res = EngineService::UpsertEngine( + engine, // engine_name + "", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, + variant.value(), + "Default", // todo - luke + "" // todo - luke ); if (create_res.has_value()) { @@ -1064,7 +1063,7 @@ cpp::result EngineService::UpdateEngine( } -cpp::result, std::string> GetEngineEntries() { +cpp::result, std::string> EngineService::GetEngines() { cortex::db::Engines engines; auto get_res = engines.GetEngines(); @@ -1075,7 +1074,7 @@ cpp::result, std::string> GetEngineEntries( return get_res.value(); } -cpp::result GetEngineEntryById(int id) { +cpp::result EngineService::GetEngineById(int id) { cortex::db::Engines engines; auto get_res = engines.GetEngineById(id); @@ -1086,7 +1085,7 @@ cpp::result GetEngineEntryById(int id) { return get_res.value(); } -cpp::result GetEngineEntryByNameAndVariant( +cpp::result EngineService::GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant = std::nullopt) { cortex::db::Engines engines; @@ -1103,7 +1102,25 @@ cpp::result GetEngineEntryByNameAndVariant return get_res.value(); } -std::string DeleteEngineEntryById(int id) { +cpp::result EngineService::UpsertEngine( + const std::string& engine_name, + const std::string& type, + const std::string& api_key, + const std::string& url, + const std::string& version, + const std::string& variant, + const std::string& status, + const std::string& metadata) { + cortex::db::Engines engines; + auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, version, variant, status, metadata); + if (upsert_res.has_value()) { + return upsert_res.value(); + } else { + return cpp::fail("Failed to upsert engine entry"); + } +} + +std::string EngineService::DeleteEngine(int id) { cortex::db::Engines engines; auto delete_res = engines.DeleteEngineById(id); if (delete_res.has_value()) { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 32a4d38f8..6081e1713 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -125,11 +125,25 @@ class EngineService : public EngineServiceI { cpp::result UpdateEngine( const std::string& engine); - cpp::result GetEngineEntryById(int id); + cpp::result, std::string> GetEngines(); - cpp::result GetEngineEntryByNameAndVariant( + cpp::result GetEngineById(int id); + + cpp::result GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant); + cpp::result UpsertEngine( + const std::string& engine_name, + const std::string& type, + const std::string& api_key, + const std::string& url, + const std::string& version, + const std::string& variant, + const std::string& status, + const std::string& metadata); + + std::string DeleteEngine(int id); + private: cpp::result DownloadEngine( const std::string& engine, const std::string& version = "latest", From cab8b442ef15e07ee81599d597d33e13fd8fe736 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Thu, 14 Nov 2024 18:09:14 +0700 Subject: [PATCH 19/33] Integrate with remote engine to run remote model --- engine/common/engine_servicei.h | 6 +- engine/controllers/engines.cc | 94 +++++++++++++++++- engine/controllers/models.cc | 2 +- engine/database/engines.h | 74 ++++++++++---- engine/services/engine_service.cc | 154 +++++++++++++++++------------- engine/services/engine_service.h | 25 +++-- engine/services/model_service.cc | 47 ++++++++- 7 files changed, 296 insertions(+), 106 deletions(-) diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index bd4f099ab..85fa87d76 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -3,8 +3,8 @@ #include #include #include +#include "database/engines.h" #include "utils/result.hpp" - // TODO: namh think of the other name struct DefaultEngineVariant { std::string engine; @@ -54,4 +54,8 @@ class EngineServiceI { virtual cpp::result UnloadEngine( const std::string& engine_name) = 0; + virtual cpp::result + GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) = 0; }; diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index a75bd1f9b..c4a2fc058 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,9 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" +#include "utils/http_util.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" - namespace { // Need to change this after we rename repositories std::string NormalizeEngine(const std::string& engine) { @@ -38,6 +38,18 @@ void Engines::ListEngine( } ret[engine] = variants; } + // Add remote engine + auto remote_engines = engine_service_->GetEngines(); + if (remote_engines.has_value()) { + for (auto engine : remote_engines.value()) { + if (engine.type == "remote") { + auto engine_json = engine.ToJson(); + Json::Value list_engine(Json::arrayValue); + list_engine.append(engine_json); + ret[engine.engine_name] = list_engine; + } + } + } auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); @@ -137,6 +149,86 @@ void Engines::InstallEngine( const std::string& engine, const std::optional version, const std::optional variant_name) { auto normalized_version = version.value_or("latest"); + if ((*(req->getJsonObject())).get("type", "").asString() == "remote") { + + auto type = (*(req->getJsonObject())).get("type", "").asString(); + auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); + auto url = (*(req->getJsonObject())).get("url", "").asString(); + auto variant = variant_name.value_or("all-platforms"); + auto status = (*(req->getJsonObject())).get("status", "Default").asString(); + std::string metadata; + if ((*(req->getJsonObject())).isMember("metadata") && + (*(req->getJsonObject()))["metadata"].isObject()) { + metadata = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .toStyledString(); + } else if ((*(req->getJsonObject())).isMember("metadata") && + !(*(req->getJsonObject()))["metadata"].isObject()) { + Json::Value res; + res["message"] = "metadata must be object"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto get_models_url = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .get("get_models_url", "") + .asString(); + + if (engine.empty() || type.empty() || url.empty()) { + Json::Value res; + res["message"] = "Engine name, type, url are required"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto exist_engine = engine_service_->GetEngineByNameAndVariant(engine); + // only allow 1 variant 1 version of a remote engine name + if (exist_engine.has_value()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = "Engine '" + engine + "' already exists"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->UpsertEngine(engine, type, api_key, url, + normalized_version, variant, + status, metadata); + if (result.has_error()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = "Remote Engine install successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + callback(resp); + } + return; + } auto result = engine_service_->InstallEngineAsyncV2( engine, normalized_version, variant_name); diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index dee126ebe..188ea30d9 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -534,7 +534,7 @@ void Models::StartModel( auto& v = result.value(); Json::Value ret; ret["message"] = "Started successfully!"; - if(v.warning) { + if (v.warning) { ret["warning"] = *(v.warning); } auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); diff --git a/engine/database/engines.h b/engine/database/engines.h index 4f5c48c86..7429d0fa2 100644 --- a/engine/database/engines.h +++ b/engine/database/engines.h @@ -1,25 +1,58 @@ #pragma once #include +#include #include +#include #include #include -#include namespace cortex::db { struct EngineEntry { - int id; - std::string engine_name; - std::string type; - std::string api_key; - std::string url; - std::string version; - std::string variant; - std::string status; - std::string metadata; - std::string date_created; - std::string date_updated; + int id; + std::string engine_name; + std::string type; + std::string api_key; + std::string url; + std::string version; + std::string variant; + std::string status; + std::string metadata; + std::string date_created; + std::string date_updated; + Json::Value ToJson() const { + Json::Value root; + Json::Reader reader; + + // Convert basic fields + root["id"] = id; + root["engine_name"] = engine_name; + root["type"] = type; + root["api_key"] = api_key; + root["url"] = url; + root["version"] = version; + root["variant"] = variant; + root["status"] = status; + root["date_created"] = date_created; + root["date_updated"] = date_updated; + + // Parse metadata string into JSON object + Json::Value metadataJson; + if (!metadata.empty()) { + bool success = reader.parse(metadata, metadataJson, + false); // false = don't collect comments + if (success) { + root["metadata"] = metadataJson; + } else { + root["metadata"] = Json::Value::null; + } + } else { + root["metadata"] = Json::Value(Json::objectValue); // empty object + } + + return root; + } }; class Engines { @@ -37,18 +70,17 @@ class Engines { Engines(SQLite::Database& db); ~Engines(); - std::optional UpsertEngine(const std::string& engine_name, - const std::string& type, - const std::string& api_key, - const std::string& url, - const std::string& version, - const std::string& variant, - const std::string& status, - const std::string& metadata); + std::optional UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); std::optional> GetEngines() const; std::optional GetEngineById(int id) const; - std::optional GetEngineByNameAndVariant(const std::string& engine_name, const std::optional variant = std::nullopt) const; + std::optional GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) const; std::optional DeleteEngineById(int id); }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 8d5d7a2ef..a305f3c27 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -200,6 +200,18 @@ cpp::result EngineService::UninstallEngineVariant( const std::string& engine, const std::optional version, const std::optional variant) { auto ne = NormalizeEngine(engine); + // TODO: handle uninstall remote engine + // only delete a remote engine if no model are using it + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_value() && exist_engine.value().type == "remote") { + auto result = DeleteEngine(exist_engine.value().id); + if (!result.empty()) { // This mean no error when delete model + CTL_ERR("Failed to delete engine: " << result); + return cpp::fail(result); + } + return cpp::result(true); + } + if (IsEngineLoaded(ne)) { CTL_INF("Engine " << ne << " is already loaded, unloading it"); auto unload_res = UnloadEngine(ne); @@ -337,16 +349,15 @@ cpp::result EngineService::DownloadEngineV2( } else { CTL_INF("Set default engine variant: " << res.value().variant); } - auto create_res = EngineService::UpsertEngine( - engine, // engine_name - "", // todo - luke - "", // todo - luke - "", // todo - luke - normalize_version, - variant.value(), - "Default", // todo - luke - "" // todo - luke - ); + auto create_res = + EngineService::UpsertEngine(engine, // engine_name + "local", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, variant.value(), + "Default", // todo - luke + "" // todo - luke + ); if (create_res.has_value()) { CTL_ERR("Failed to create engine entry: " << create_res->engine_name); @@ -781,9 +792,14 @@ cpp::result EngineService::LoadEngine( return {}; } - // TODO (alex): Need to change this, now Hard code for testing remote engine + // Check for remote engine if (engine_name != kLlamaEngine && engine_name != kOnnxEngine && engine_name != kTrtLlmEngine) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + engines_[engine_name].engine = new remote_engine::RemoteEngine(); auto& en = std::get(engines_[ne].engine); auto config = file_manager_utils::GetCortexConfig(); @@ -986,11 +1002,15 @@ EngineService::GetLatestEngineVersion(const std::string& engine) const { } cpp::result EngineService::IsEngineReady( - const std::string& engine) const { + const std::string& engine) { auto ne = NormalizeEngine(engine); - // Hard code to test remote engine + // Check for remote engine if (engine != kLlamaRepo && engine != kTrtLlmRepo && engine != kOnnxRepo) { + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine + "' is not installed"); + } return true; } @@ -1082,70 +1102,70 @@ cpp::result EngineService::UpdateEngine( .to = latest_version->tag_name}; } +cpp::result, std::string> +EngineService::GetEngines() { + cortex::db::Engines engines; + auto get_res = engines.GetEngines(); -cpp::result, std::string> EngineService::GetEngines() { - cortex::db::Engines engines; - auto get_res = engines.GetEngines(); - - if (!get_res.has_value()) { - return cpp::fail("Failed to get engine entries"); - } - - return get_res.value(); + if (!get_res.has_value()) { + return cpp::fail("Failed to get engine entries"); + } + + return get_res.value(); } -cpp::result EngineService::GetEngineById(int id) { - cortex::db::Engines engines; - auto get_res = engines.GetEngineById(id); - - if (!get_res.has_value()) { - return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); - } - - return get_res.value(); +cpp::result EngineService::GetEngineById( + int id) { + cortex::db::Engines engines; + auto get_res = engines.GetEngineById(id); + + if (!get_res.has_value()) { + return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); + } + + return get_res.value(); } -cpp::result EngineService::GetEngineByNameAndVariant( - const std::string& engine_name, const std::optional variant = std::nullopt) { - - cortex::db::Engines engines; - auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); - - if (!get_res.has_value()) { - if (variant.has_value()) { - return cpp::fail("Variant " + variant.value() + " not found for engine " + engine_name); - } else { - return cpp::fail("Engine " + engine_name + " not found"); - } +cpp::result +EngineService::GetEngineByNameAndVariant( + const std::string& engine_name, const std::optional variant) { + + cortex::db::Engines engines; + auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); + + if (!get_res.has_value()) { + if (variant.has_value()) { + return cpp::fail("Variant " + variant.value() + " not found for engine " + + engine_name); + } else { + return cpp::fail("Engine " + engine_name + " not found"); } - - return get_res.value(); + } + + return get_res.value(); } cpp::result EngineService::UpsertEngine( - const std::string& engine_name, - const std::string& type, - const std::string& api_key, - const std::string& url, - const std::string& version, - const std::string& variant, - const std::string& status, - const std::string& metadata) { - cortex::db::Engines engines; - auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, version, variant, status, metadata); - if (upsert_res.has_value()) { - return upsert_res.value(); - } else { - return cpp::fail("Failed to upsert engine entry"); - } + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + cortex::db::Engines engines; + auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, + version, variant, status, metadata); + if (upsert_res.has_value()) { + return upsert_res.value(); + } else { + return cpp::fail("Failed to upsert engine entry"); + } } std::string EngineService::DeleteEngine(int id) { - cortex::db::Engines engines; - auto delete_res = engines.DeleteEngineById(id); - if (delete_res.has_value()) { - return delete_res.value(); - } else { - return "Failed to delete engine entry"; - } + cortex::db::Engines engines; + auto delete_res = engines.DeleteEngineById(id); + if (delete_res.has_value()) { + return delete_res.value(); + } else { + return ""; + } } \ No newline at end of file diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 6081e1713..75b5289ac 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -1,15 +1,17 @@ #pragma once #include +#include #include #include -#include #include -#include +#include +#include "common/engine_servicei.h" #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" #include "database/engines.h" +#include "extensions/remote-engine/remote_engine.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -17,8 +19,6 @@ #include "utils/github_release_utils.h" #include "utils/result.hpp" #include "utils/system_info_utils.h" -#include "common/engine_servicei.h" - namespace system_info_utils { struct SystemInfo; } @@ -76,7 +76,7 @@ class EngineService : public EngineServiceI { std::vector GetEngineInfoList() const; - cpp::result IsEngineReady(const std::string& engine) const; + cpp::result IsEngineReady(const std::string& engine); cpp::result InstallEngineAsync( const std::string& engine, const std::string& version = "latest", const std::string& src = ""); @@ -130,17 +130,14 @@ class EngineService : public EngineServiceI { cpp::result GetEngineById(int id); cpp::result GetEngineByNameAndVariant( - const std::string& engine_name, const std::optional variant); + const std::string& engine_name, + const std::optional variant = std::nullopt); cpp::result UpsertEngine( - const std::string& engine_name, - const std::string& type, - const std::string& api_key, - const std::string& url, - const std::string& version, - const std::string& variant, - const std::string& status, - const std::string& metadata); + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); std::string DeleteEngine(int id); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 3a8507c22..097c2f4ee 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -621,6 +621,50 @@ cpp::result ModelService::StartModel( .string()); auto mc = yaml_handler.GetModelConfig(); + // Running remote model + if (mc.engine != kLlamaEngine && mc.engine != kOnnxEngine && + mc.engine != kTrtLlmEngine) { + + config::RemoteModelConfig remote_mc; + remote_mc.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto remote_engine_entry = + engine_svc_->GetEngineByNameAndVariant(mc.engine); + if (remote_engine_entry.has_error()) { + CTL_WRN("Remote engine error: " + model_entry.error()); + return cpp::fail(remote_engine_entry.error()); + } + auto remote_engine_json = remote_engine_entry.value().ToJson(); + json_data = remote_mc.ToJson(); + + json_data["api_key"] = std::move(remote_engine_json["api_key"]); + json_data["model_path"] = + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string(); + json_data["metadata"] = std::move(remote_engine_json["metadata"]); + + auto ir = + inference_svc_->LoadModel(std::make_shared(json_data)); + auto status = std::get<0>(ir)["status_code"].asInt(); + auto data = std::get<1>(ir); + if (status == httplib::StatusCode::OK_200) { + return StartModelResult{.success = true, .warning = ""}; + } else if (status == httplib::StatusCode::Conflict_409) { + CTL_INF("Model '" + model_handle + "' is already loaded"); + return StartModelResult{.success = true, .warning = ""}; + } else { + // only report to user the error + CTL_ERR("Model failed to start with status code: " << status); + return cpp::fail("Model failed to start: " + + data["message"].asString()); + } + } + + // end hard code + json_data = mc.ToJson(); if (mc.files.size() > 0) { // TODO(sang) support multiple files @@ -745,7 +789,8 @@ cpp::result ModelService::StartModel( return cpp::fail( "Not enough VRAM - required: " + std::to_string(vram_needed_MiB) + " MiB, available: " + std::to_string(free_vram_MiB) + - " MiB - Should adjust ngl to " + std::to_string(free_vram_MiB / (vram_needed_MiB / ngl) - 1)); + " MiB - Should adjust ngl to " + + std::to_string(free_vram_MiB / (vram_needed_MiB / ngl) - 1)); } if (ram_needed_MiB > free_ram_MiB) { From 9d50f5fccb5fcd2b6dc4ed92a78a1b3d5b179245 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Fri, 15 Nov 2024 00:03:24 +0700 Subject: [PATCH 20/33] error handling and response transform --- .../extensions/remote-engine/remote_engine.cc | 113 ++++++++++++++++-- 1 file changed, 101 insertions(+), 12 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index c59c3e865..0fde13628 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -120,7 +120,6 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() { struct curl_slist* headers = nullptr; headers = curl_slist_append(headers, api_key_template_.c_str()); - std::cout << "api_key: " << api_key_template_ << std::endl; headers = curl_slist_append(headers, "Content-Type: application/json"); @@ -162,7 +161,6 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( if (!config.api_key.empty()) { headers = curl_slist_append(headers, api_key_template_.c_str()); - std::cout << "api_key: " << api_key_template_ << std::endl; } headers = curl_slist_append(headers, "Content-Type: application/json"); @@ -386,15 +384,48 @@ void RemoteEngine::HandleChatCompletion( bool is_stream = json_body->isMember("stream") && (*json_body)["stream"].asBool(); Json::FastWriter writer; - std::string request_body = writer.write((*json_body)); - std::cout << "template: " - << model_config->transform_req["chat_completions"]["template"] - .as() - << std::endl; - std::string result = renderer_.render( - model_config->transform_req["chat_completions"]["template"] - .as(), - (*json_body)); + // Transform request + std::string result; + try { + // Check if required YAML nodes exist + if (!model_config->transform_req["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_req"); + } + if (!model_config->transform_req["chat_completions"]["template"]) { + throw std::runtime_error("Missing 'template' node in chat_completions"); + } + + // Validate JSON body + if (!json_body || json_body->isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Get template string with error check + std::string template_str; + try { + template_str = model_config->transform_req["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error("Failed to convert template node to string: " + + std::string(e.what())); + } + + // Render with error handling + try { + result = renderer_.render(template_str, *json_body); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + result = (*json_body).toStyledString(); + } + if (is_stream) { MakeStreamingChatCompletionRequest(*model_config, result, callback); } else { @@ -426,13 +457,71 @@ void RemoteEngine::HandleChatCompletion( callback(std::move(status), std::move(error)); return; } + + // Transform Response + std::string response_str; + try { + // Check if required YAML nodes exist + if (!model_config->transform_resp["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_resp"); + } + if (!model_config->transform_resp["chat_completions"]["template"]) { + throw std::runtime_error("Missing 'template' node in chat_completions"); + } + + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Get template string with error check + std::string template_str; + try { + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error("Failed to convert template node to string: " + + std::string(e.what())); + } + + // Render with error handling + try { + response_str = renderer_.render(template_str, response_json); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + response_str = response_json.toStyledString(); + } + + Json::Reader reader_final; + Json::Value response_json_final; + if (!reader_final.parse(response_str, response_json_final)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + Json::Value status; status["is_done"] = true; status["has_error"] = false; status["is_stream"] = false; status["status_code"] = k200OK; - callback(std::move(status), std::move(response_json)); + callback(std::move(status), std::move(response_json_final)); } } From fa434a4dd156ac314c0b04c73651673590bd0dcb Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Fri, 15 Nov 2024 12:47:34 +0700 Subject: [PATCH 21/33] Support for stream request --- .../extensions/remote-engine/remote_engine.cc | 10 +++--- .../extensions/remote-engine/remote_engine.h | 35 ++++++++++--------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 0fde13628..27a436ef7 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -4,6 +4,7 @@ #include #include namespace remote_engine { + constexpr const int k200OK = 200; constexpr const int k400BadRequest = 400; constexpr const int k409Conflict = 409; @@ -12,7 +13,7 @@ constexpr const int kFileLoggerOption = 0; CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, - std::function callback) { + const std::function& callback) { CURL* curl = curl_easy_init(); CurlResponse response; @@ -36,7 +37,10 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Cache-Control: no-cache"); headers = curl_slist_append(headers, "Connection: keep-alive"); - StreamContext context{callback, ""}; + StreamContext context{ + std::make_shared>( + callback), + ""}; curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); @@ -418,7 +422,6 @@ void RemoteEngine::HandleChatCompletion( throw std::runtime_error("Template rendering error: " + std::string(e.what())); } - } catch (const std::exception& e) { // Log error and potentially rethrow or handle accordingly LOG_WARN << "Error in TransformRequest: " << e.what(); @@ -493,7 +496,6 @@ void RemoteEngine::HandleChatCompletion( throw std::runtime_error("Template rendering error: " + std::string(e.what())); } - } catch (const std::exception& e) { // Log error and potentially rethrow or handle accordingly LOG_WARN << "Error in TransformRequest: " << e.what(); diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 3b6226d95..c181e832e 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -13,9 +13,8 @@ // Helper for CURL response namespace remote_engine { - struct StreamContext { - std::function callback; + std::shared_ptr> callback; std::string buffer; }; @@ -37,32 +36,34 @@ static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, continue; // Remove "data: " prefix if present - if (line.substr(0, 6) == "data: ") { - line = line.substr(6); - } + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } // Skip [DONE] message - if (line == "[DONE]") { + std::cout << line << std::endl; + if (line == "data: [DONE]") { Json::Value status; status["is_done"] = true; status["has_error"] = false; status["is_stream"] = true; status["status_code"] = 200; - context->callback(std::move(status), Json::Value()); - continue; + (*context->callback)(std::move(status), Json::Value()); + break; } // Parse the JSON Json::Value chunk_json; + chunk_json["data"] = line + "\n\n"; Json::Reader reader; - if (reader.parse(line, chunk_json)) { - Json::Value status; - status["is_done"] = false; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = 200; - context->callback(std::move(status), std::move(chunk_json)); - } + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); } return size * nmemb; @@ -100,7 +101,7 @@ class RemoteEngine : public EngineI { const std::string& method = "POST"); CurlResponse MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, - std::function callback); + const std::function& callback); CurlResponse MakeGetModelsRequest(); // Internal model management From 3f9f451ed0838aba6f7439caa9d523f9561df622 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 26 Nov 2024 08:56:42 +0700 Subject: [PATCH 22/33] chore: fix conflicts --- docs/docs/capabilities/embeddings.md | 8 -------- docs/static/openapi/cortex.json | 14 -------------- 2 files changed, 22 deletions(-) diff --git a/docs/docs/capabilities/embeddings.md b/docs/docs/capabilities/embeddings.md index c7d211739..44f153556 100644 --- a/docs/docs/capabilities/embeddings.md +++ b/docs/docs/capabilities/embeddings.md @@ -6,10 +6,6 @@ title: Embeddings ::: cortex.cpp now support embeddings endpoint with fully OpenAI compatible. -<<<<<<< HEAD - -======= ->>>>>>> a055f6906e43b3c3398874d540136425adae9114 For embeddings API usage please refer to [API references](/api-reference#tag/chat/POST/v1/embeddings). This tutorial show you how to use embeddings in cortex with openai python SDK. @@ -105,7 +101,3 @@ response = client.embeddings.create(input = [12,44,123], model=MODEL) # input as array of arrays contain tokens response = client.embeddings.create(input = [[912,312,54],[12,433,1241]], model=MODEL) ``` -<<<<<<< HEAD - -======= ->>>>>>> a055f6906e43b3c3398874d540136425adae9114 diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index d810fe09d..fdb5c4ed2 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -209,11 +209,7 @@ }, { "type": "array", -<<<<<<< HEAD - "description" : "The array of strings that will be turned into an embedding.", -======= "description": "The array of strings that will be turned into an embedding.", ->>>>>>> a055f6906e43b3c3398874d540136425adae9114 "items": { "type": "string" } @@ -223,21 +219,11 @@ "description": "The array of integers that will be turned into an embedding.", "items": { "type": "integer" -<<<<<<< HEAD - -======= ->>>>>>> a055f6906e43b3c3398874d540136425adae9114 } }, { "type": "array", -<<<<<<< HEAD - - "description" : "The array of arrays containing integers that will be turned into an embedding.", - -======= "description": "The array of arrays containing integers that will be turned into an embedding.", ->>>>>>> a055f6906e43b3c3398874d540136425adae9114 "items": { "type": "array", "items": { From 5473a75840a5b9ddec12b232a2fd71ead774d0f5 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 26 Nov 2024 14:00:00 +0700 Subject: [PATCH 23/33] feat: anthropic --- .../extensions/remote-engine/remote_engine.cc | 19 ++++++++++++++++++- .../extensions/remote-engine/remote_engine.h | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 27a436ef7..374df3c55 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -3,13 +3,18 @@ #include #include #include +#include "utils/logging_utils.h" namespace remote_engine { - +namespace { constexpr const int k200OK = 200; constexpr const int k400BadRequest = 400; constexpr const int k409Conflict = 409; constexpr const int k500InternalServerError = 500; constexpr const int kFileLoggerOption = 0; +bool is_anthropic(const std::string& model) { + return model.find("claude") != std::string::npos; +} +} // namespace CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, @@ -166,6 +171,11 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( headers = curl_slist_append(headers, api_key_template_.c_str()); } + + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } headers = curl_slist_append(headers, "Content-Type: application/json"); curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); @@ -200,6 +210,13 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, ModelConfig model_config; model_config.model = model; + if (is_anthropic(model)) { + if (!config["version"]) { + CTL_ERR("Missing version for model: " << model); + return false; + } + model_config.version = config["version"].as(); + } // Required fields if (!config["api_key_template"]) { diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index c181e832e..46e65c852 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -80,6 +80,7 @@ class RemoteEngine : public EngineI { // Model configuration struct ModelConfig { std::string model; + std::string version; std::string api_key; std::string url; YAML::Node transform_req; From 071d84c80daefe04626b597ae6b564c487ed0cc5 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 27 Nov 2024 07:28:21 +0700 Subject: [PATCH 24/33] feat: support anthropic --- .../remote-engine/TemplateRenderer.cc | 1 + .../extensions/remote-engine/remote_engine.cc | 140 +++++++++++++++++- .../extensions/remote-engine/remote_engine.h | 55 +------ engine/services/engine_service.cc | 1 + engine/utils/logging_utils.h | 2 + 5 files changed, 146 insertions(+), 53 deletions(-) diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 3c3a0ea00..a000e925b 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -6,6 +6,7 @@ #include "TemplateRenderer.h" #include #include +#include "utils/logging_utils.h" namespace remote_engine { TemplateRenderer::TemplateRenderer() { // Configure Inja environment diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 374df3c55..96727c941 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -3,6 +3,7 @@ #include #include #include +#include "utils/json_helper.h" #include "utils/logging_utils.h" namespace remote_engine { namespace { @@ -14,8 +15,138 @@ constexpr const int kFileLoggerOption = 0; bool is_anthropic(const std::string& model) { return model.find("claude") != std::string::npos; } + +struct AnthropicChunk { + std::string type; + std::string id; + int index; + std::string msg; + std::string model; + std::string stop_reason; + bool should_ignore = false; + + AnthropicChunk(const std::string& str) { + if (str.size() > 6) { + std::string s = str.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + type = root["type"].asString(); + if (type == "message_start") { + id = root["message"]["id"].asString(); + model = root["message"]["model"].asString(); + } else if (type == "content_block_delta") { + index = root["index"].asInt(); + if (root["delta"]["type"].asString() == "text_delta") { + msg = root["delta"]["text"].asString(); + } + } else if (type == "message_delta") { + stop_reason = root["delta"]["stop_reason"].asString(); + } else { + // ignore other messages + should_ignore = true; + } + } catch (const std::exception& e) { + should_ignore = true; + CTL_WRN("JSON parse error: " << e.what()); + } + } else { + should_ignore = true; + } + } + + std::string ToOpenAiFormatString() { + Json::Value root; + root["id"] = id; + root["object"] = "chat.completion.chunk"; + root["created"] = Json::Value(); + root["model"] = model; + root["system_fingerprint"] = "fp_e76890f0c3"; + Json::Value choices(Json::arrayValue); + Json::Value choice; + Json::Value content; + choice["index"] = 0; + content["content"] = msg; + if (type == "message_start") { + content["role"] = "assistant"; + content["refusal"] = Json::Value(); + } + choice["delta"] = content; + choice["finish_reason"] = stop_reason.empty() ? Json::Value() : stop_reason; + choices.append(choice); + root["choices"] = choices; + return "data: " + json_helper::DumpJsonString(root); + } +}; + } // namespace +size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + CTL_TRC(line); + + // Skip empty lines + if (line.empty() || line == "\r" || + line.find("event:") != std::string::npos) + continue; + + // Remove "data: " prefix if present + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } + + // Skip [DONE] message + // std::cout << line << std::endl; + if (line == "data: [DONE]" || + line.find("message_stop") != std::string::npos) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), Json::Value()); + break; + } + + // Parse the JSON + Json::Value chunk_json; + if (is_anthropic(context->model)) { + AnthropicChunk ac(line); + if (ac.should_ignore) + continue; + ac.model = context->model; + if (ac.type == "message_start") { + context->id = ac.id; + } else { + ac.id = context->id; + } + chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; + } else { + chunk_json["data"] = line + "\n\n"; + } + Json::Reader reader; + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); + } + + return size * nmemb; +} + CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, const std::function& callback) { @@ -37,6 +168,11 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, api_key_template_.c_str()); } + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } + headers = curl_slist_append(headers, "Content-Type: application/json"); headers = curl_slist_append(headers, "Accept: text/event-stream"); headers = curl_slist_append(headers, "Cache-Control: no-cache"); @@ -45,7 +181,7 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( StreamContext context{ std::make_shared>( callback), - ""}; + "", "", config.model}; curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); @@ -249,6 +385,7 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, std::unique_lock lock(models_mutex_); models_[model] = std::move(model_config); } + CTL_DBG("LoadModelConfig successfully: " << model << ", " << yaml_path); return true; } catch (const YAML::Exception& e) { @@ -339,6 +476,7 @@ void RemoteEngine::LoadModel( status["is_stream"] = false; status["status_code"] = k200OK; callback(std::move(status), std::move(response)); + CTL_INF("Model loaded successfully: " << model); } void RemoteEngine::UnloadModel( diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 46e65c852..80922ac0d 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -16,59 +16,10 @@ namespace remote_engine { struct StreamContext { std::shared_ptr> callback; std::string buffer; + // Cache value for Anthropic + std::string id; + std::string model; }; - -static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, - void* userdata) { - auto* context = static_cast(userdata); - std::string chunk(ptr, size * nmemb); - - context->buffer += chunk; - - // Process complete lines - size_t pos; - while ((pos = context->buffer.find('\n')) != std::string::npos) { - std::string line = context->buffer.substr(0, pos); - context->buffer = context->buffer.substr(pos + 1); - - // Skip empty lines - if (line.empty() || line == "\r") - continue; - - // Remove "data: " prefix if present - // if (line.substr(0, 6) == "data: ") - // { - // line = line.substr(6); - // } - - // Skip [DONE] message - std::cout << line << std::endl; - if (line == "data: [DONE]") { - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = 200; - (*context->callback)(std::move(status), Json::Value()); - break; - } - - // Parse the JSON - Json::Value chunk_json; - chunk_json["data"] = line + "\n\n"; - Json::Reader reader; - - Json::Value status; - status["is_done"] = false; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = 200; - (*context->callback)(std::move(status), std::move(chunk_json)); - } - - return size * nmemb; -} - struct CurlResponse { std::string body; bool error{false}; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index a305f3c27..bce9cc76e 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -816,6 +816,7 @@ cpp::result EngineService::LoadEngine( } else { CTL_WRN("Method SetLogLevel is not supported yet"); } + CTL_INF("Loaded engine: " << engine_name); return {}; } diff --git a/engine/utils/logging_utils.h b/engine/utils/logging_utils.h index 2c5affcd4..191959dba 100644 --- a/engine/utils/logging_utils.h +++ b/engine/utils/logging_utils.h @@ -9,6 +9,8 @@ inline bool log_verbose = false; inline bool is_server = false; // Only use trantor log +#define CTL_TRC(msg) LOG_TRACE << msg; + #define CTL_DBG(msg) LOG_DEBUG << msg; #define CTL_INF(msg) LOG_INFO << msg; From 9444dcbba3b9ca1e3c9409a15c2d47a1f6c103ba Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 27 Nov 2024 07:28:21 +0700 Subject: [PATCH 25/33] feat: support anthropic --- engine/config/model_config.h | 66 +++++++- .../remote-engine/TemplateRenderer.cc | 1 + .../extensions/remote-engine/remote_engine.cc | 157 +++++++++++++++++- .../extensions/remote-engine/remote_engine.h | 55 +----- engine/services/engine_service.cc | 1 + engine/utils/logging_utils.h | 2 + 6 files changed, 228 insertions(+), 54 deletions(-) diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 1ff127329..13a4f58a9 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -15,6 +15,47 @@ #include "yaml-cpp/yaml.h" namespace config { +namespace { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string kAnthropicTransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; +} // namespace + struct RemoteModelConfig { std::string model; std::string api_key_template; @@ -38,14 +79,37 @@ struct RemoteModelConfig { json.get("api_key_template", api_key_template).asString(); engine = json.get("engine", engine).asString(); version = json.get("version", version).asString(); - created = json.get("created", static_cast(created)).asUInt64(); + created = + json.get("created", static_cast(created)).asUInt64(); object = json.get("object", object).asString(); owned_by = json.get("owned_by", owned_by).asString(); // Load JSON object fields directly inference_params = json.get("inference_params", inference_params); TransformReq = json.get("TransformReq", TransformReq); + // Use default template if it is empty, currently we only support 2 remote engines + auto is_anthropic = [](const std::string& model) { + return model.find("claude") != std::string::npos; + }; + if (TransformReq["chat_completions"]["template"].isNull()) { + if (is_anthropic(model)) { + TransformReq["chat_completions"]["template"] = + kAnthropicTransformReqTemplate; + } else { + TransformReq["chat_completions"]["template"] = + kOpenAITransformReqTemplate; + } + } TransformResp = json.get("TransformResp", TransformResp); + if (TransformResp["chat_completions"]["template"].isNull()) { + if (is_anthropic(model)) { + TransformResp["chat_completions"]["template"] = + kAnthropicTransformRespTemplate; + } else { + TransformResp["chat_completions"]["template"] = + kOpenAITransformRespTemplate; + } + } metadata = json.get("metadata", metadata); } diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/TemplateRenderer.cc index 3c3a0ea00..a000e925b 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/TemplateRenderer.cc @@ -6,6 +6,7 @@ #include "TemplateRenderer.h" #include #include +#include "utils/logging_utils.h" namespace remote_engine { TemplateRenderer::TemplateRenderer() { // Configure Inja environment diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 374df3c55..9c2d769c4 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -3,6 +3,7 @@ #include #include #include +#include "utils/json_helper.h" #include "utils/logging_utils.h" namespace remote_engine { namespace { @@ -14,8 +15,138 @@ constexpr const int kFileLoggerOption = 0; bool is_anthropic(const std::string& model) { return model.find("claude") != std::string::npos; } + +struct AnthropicChunk { + std::string type; + std::string id; + int index; + std::string msg; + std::string model; + std::string stop_reason; + bool should_ignore = false; + + AnthropicChunk(const std::string& str) { + if (str.size() > 6) { + std::string s = str.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + type = root["type"].asString(); + if (type == "message_start") { + id = root["message"]["id"].asString(); + model = root["message"]["model"].asString(); + } else if (type == "content_block_delta") { + index = root["index"].asInt(); + if (root["delta"]["type"].asString() == "text_delta") { + msg = root["delta"]["text"].asString(); + } + } else if (type == "message_delta") { + stop_reason = root["delta"]["stop_reason"].asString(); + } else { + // ignore other messages + should_ignore = true; + } + } catch (const std::exception& e) { + should_ignore = true; + CTL_WRN("JSON parse error: " << e.what()); + } + } else { + should_ignore = true; + } + } + + std::string ToOpenAiFormatString() { + Json::Value root; + root["id"] = id; + root["object"] = "chat.completion.chunk"; + root["created"] = Json::Value(); + root["model"] = model; + root["system_fingerprint"] = "fp_e76890f0c3"; + Json::Value choices(Json::arrayValue); + Json::Value choice; + Json::Value content; + choice["index"] = 0; + content["content"] = msg; + if (type == "message_start") { + content["role"] = "assistant"; + content["refusal"] = Json::Value(); + } + choice["delta"] = content; + choice["finish_reason"] = stop_reason.empty() ? Json::Value() : stop_reason; + choices.append(choice); + root["choices"] = choices; + return "data: " + json_helper::DumpJsonString(root); + } +}; + } // namespace +size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + CTL_TRC(line); + + // Skip empty lines + if (line.empty() || line == "\r" || + line.find("event:") != std::string::npos) + continue; + + // Remove "data: " prefix if present + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } + + // Skip [DONE] message + // std::cout << line << std::endl; + if (line == "data: [DONE]" || + line.find("message_stop") != std::string::npos) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), Json::Value()); + break; + } + + // Parse the JSON + Json::Value chunk_json; + if (is_anthropic(context->model)) { + AnthropicChunk ac(line); + if (ac.should_ignore) + continue; + ac.model = context->model; + if (ac.type == "message_start") { + context->id = ac.id; + } else { + ac.id = context->id; + } + chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; + } else { + chunk_json["data"] = line + "\n\n"; + } + Json::Reader reader; + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); + } + + return size * nmemb; +} + CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, const std::function& callback) { @@ -37,6 +168,11 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, api_key_template_.c_str()); } + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } + headers = curl_slist_append(headers, "Content-Type: application/json"); headers = curl_slist_append(headers, "Accept: text/event-stream"); headers = curl_slist_append(headers, "Cache-Control: no-cache"); @@ -45,7 +181,7 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( StreamContext context{ std::make_shared>( callback), - ""}; + "", "", config.model}; curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); @@ -249,6 +385,7 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, std::unique_lock lock(models_mutex_); models_[model] = std::move(model_config); } + CTL_DBG("LoadModelConfig successfully: " << model << ", " << yaml_path); return true; } catch (const YAML::Exception& e) { @@ -339,6 +476,7 @@ void RemoteEngine::LoadModel( status["is_stream"] = false; status["status_code"] = k200OK; callback(std::move(status), std::move(response)); + CTL_INF("Model loaded successfully: " << model); } void RemoteEngine::UnloadModel( @@ -432,6 +570,23 @@ void RemoteEngine::HandleChatCompletion( std::string(e.what())); } + // Parse system for anthropic + if (is_anthropic(model)) { + bool has_system = false; + Json::Value msgs(Json::arrayValue); + for (auto& kv : (*json_body)["messages"]) { + if (kv["role"].asString() == "system") { + (*json_body)["system"] = kv["content"].asString(); + has_system = true; + } else { + msgs.append(kv); + } + } + if (has_system) { + (*json_body)["messages"] = msgs; + } + } + // Render with error handling try { result = renderer_.render(template_str, *json_body); diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 46e65c852..80922ac0d 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -16,59 +16,10 @@ namespace remote_engine { struct StreamContext { std::shared_ptr> callback; std::string buffer; + // Cache value for Anthropic + std::string id; + std::string model; }; - -static size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, - void* userdata) { - auto* context = static_cast(userdata); - std::string chunk(ptr, size * nmemb); - - context->buffer += chunk; - - // Process complete lines - size_t pos; - while ((pos = context->buffer.find('\n')) != std::string::npos) { - std::string line = context->buffer.substr(0, pos); - context->buffer = context->buffer.substr(pos + 1); - - // Skip empty lines - if (line.empty() || line == "\r") - continue; - - // Remove "data: " prefix if present - // if (line.substr(0, 6) == "data: ") - // { - // line = line.substr(6); - // } - - // Skip [DONE] message - std::cout << line << std::endl; - if (line == "data: [DONE]") { - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = 200; - (*context->callback)(std::move(status), Json::Value()); - break; - } - - // Parse the JSON - Json::Value chunk_json; - chunk_json["data"] = line + "\n\n"; - Json::Reader reader; - - Json::Value status; - status["is_done"] = false; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = 200; - (*context->callback)(std::move(status), std::move(chunk_json)); - } - - return size * nmemb; -} - struct CurlResponse { std::string body; bool error{false}; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index a305f3c27..bce9cc76e 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -816,6 +816,7 @@ cpp::result EngineService::LoadEngine( } else { CTL_WRN("Method SetLogLevel is not supported yet"); } + CTL_INF("Loaded engine: " << engine_name); return {}; } diff --git a/engine/utils/logging_utils.h b/engine/utils/logging_utils.h index 2c5affcd4..191959dba 100644 --- a/engine/utils/logging_utils.h +++ b/engine/utils/logging_utils.h @@ -9,6 +9,8 @@ inline bool log_verbose = false; inline bool is_server = false; // Only use trantor log +#define CTL_TRC(msg) LOG_TRACE << msg; + #define CTL_DBG(msg) LOG_DEBUG << msg; #define CTL_INF(msg) LOG_INFO << msg; From 1b30777c0bb6fadf7e7dc7586528beab67e450d4 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 29 Nov 2024 07:33:25 +0700 Subject: [PATCH 26/33] chore: rename --- engine/CMakeLists.txt | 2 +- engine/cli/CMakeLists.txt | 2 +- .../extensions/remote-engine/remote_engine.cc | 4 ++-- .../extensions/remote-engine/remote_engine.h | 2 +- ...mplateRenderer.cc => template_renderer.cc} | 22 +++++++++---------- ...TemplateRenderer.h => template_renderer.h} | 8 +++---- 6 files changed, 20 insertions(+), 20 deletions(-) rename engine/extensions/remote-engine/{TemplateRenderer.cc => template_renderer.cc} (86%) rename engine/extensions/remote-engine/{TemplateRenderer.h => template_renderer.h} (73%) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index f3b9695a4..6b2ede2fd 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -144,7 +144,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/TemplateRenderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index dd9c3764c..ae2f16f1d 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -84,7 +84,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/TemplateRenderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 9c2d769c4..894eeb009 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -589,7 +589,7 @@ void RemoteEngine::HandleChatCompletion( // Render with error handling try { - result = renderer_.render(template_str, *json_body); + result = renderer_.Render(template_str, *json_body); } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + std::string(e.what())); @@ -663,7 +663,7 @@ void RemoteEngine::HandleChatCompletion( // Render with error handling try { - response_str = renderer_.render(template_str, response_json); + response_str = renderer_.Render(template_str, response_json); } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + std::string(e.what())); diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 80922ac0d..80c703f8f 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -8,7 +8,7 @@ #include #include #include "cortex-common/EngineI.h" -#include "extensions/remote-engine/TemplateRenderer.h" +#include "extensions/remote-engine/template_renderer.h" #include "utils/file_logger.h" // Helper for CURL response diff --git a/engine/extensions/remote-engine/TemplateRenderer.cc b/engine/extensions/remote-engine/template_renderer.cc similarity index 86% rename from engine/extensions/remote-engine/TemplateRenderer.cc rename to engine/extensions/remote-engine/template_renderer.cc index a000e925b..15514d17c 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.cc +++ b/engine/extensions/remote-engine/template_renderer.cc @@ -3,7 +3,7 @@ #undef min #undef max #endif -#include "TemplateRenderer.h" +#include "template_renderer.h" #include #include #include "utils/logging_utils.h" @@ -28,11 +28,11 @@ TemplateRenderer::TemplateRenderer() { }); } -std::string TemplateRenderer::render(const std::string& tmpl, +std::string TemplateRenderer::Render(const std::string& tmpl, const Json::Value& data) { try { // Convert Json::Value to nlohmann::json - auto json_data = convertJsonValue(data); + auto json_data = ConvertJsonValue(data); // Create the input data structure expected by the template nlohmann::json template_data; @@ -62,7 +62,7 @@ std::string TemplateRenderer::render(const std::string& tmpl, } } -nlohmann::json TemplateRenderer::convertJsonValue(const Json::Value& input) { +nlohmann::json TemplateRenderer::ConvertJsonValue(const Json::Value& input) { if (input.isNull()) { return nullptr; } else if (input.isBool()) { @@ -78,20 +78,20 @@ nlohmann::json TemplateRenderer::convertJsonValue(const Json::Value& input) { } else if (input.isArray()) { nlohmann::json arr = nlohmann::json::array(); for (const auto& element : input) { - arr.push_back(convertJsonValue(element)); + arr.push_back(ConvertJsonValue(element)); } return arr; } else if (input.isObject()) { nlohmann::json obj = nlohmann::json::object(); for (const auto& key : input.getMemberNames()) { - obj[key] = convertJsonValue(input[key]); + obj[key] = ConvertJsonValue(input[key]); } return obj; } return nullptr; } -Json::Value TemplateRenderer::convertNlohmannJson(const nlohmann::json& input) { +Json::Value TemplateRenderer::ConvertNlohmannJson(const nlohmann::json& input) { if (input.is_null()) { return Json::Value(); } else if (input.is_boolean()) { @@ -107,24 +107,24 @@ Json::Value TemplateRenderer::convertNlohmannJson(const nlohmann::json& input) { } else if (input.is_array()) { Json::Value arr(Json::arrayValue); for (const auto& element : input) { - arr.append(convertNlohmannJson(element)); + arr.append(ConvertNlohmannJson(element)); } return arr; } else if (input.is_object()) { Json::Value obj(Json::objectValue); for (auto it = input.begin(); it != input.end(); ++it) { - obj[it.key()] = convertNlohmannJson(it.value()); + obj[it.key()] = ConvertNlohmannJson(it.value()); } return obj; } return Json::Value(); } -std::string TemplateRenderer::renderFile(const std::string& template_path, +std::string TemplateRenderer::RenderFile(const std::string& template_path, const Json::Value& data) { try { // Convert Json::Value to nlohmann::json - auto json_data = convertJsonValue(data); + auto json_data = ConvertJsonValue(data); // Load and render template return env_.render_file(template_path, json_data); diff --git a/engine/extensions/remote-engine/TemplateRenderer.h b/engine/extensions/remote-engine/template_renderer.h similarity index 73% rename from engine/extensions/remote-engine/TemplateRenderer.h rename to engine/extensions/remote-engine/template_renderer.h index 7f2f7fd88..f59e7cc93 100644 --- a/engine/extensions/remote-engine/TemplateRenderer.h +++ b/engine/extensions/remote-engine/template_renderer.h @@ -21,16 +21,16 @@ class TemplateRenderer { ~TemplateRenderer() = default; // Convert Json::Value to nlohmann::json - static nlohmann::json convertJsonValue(const Json::Value& input); + static nlohmann::json ConvertJsonValue(const Json::Value& input); // Convert nlohmann::json to Json::Value - static Json::Value convertNlohmannJson(const nlohmann::json& input); + static Json::Value ConvertNlohmannJson(const nlohmann::json& input); // Render template with data - std::string render(const std::string& tmpl, const Json::Value& data); + std::string Render(const std::string& tmpl, const Json::Value& data); // Load template from file and render - std::string renderFile(const std::string& template_path, + std::string RenderFile(const std::string& template_path, const Json::Value& data); private: From e3371a01979b079195652a3e0db8c206232b133e Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 29 Nov 2024 09:22:42 +0700 Subject: [PATCH 27/33] chore: cleanup and fix unit tests --- engine/controllers/models.cc | 8 +- engine/database/engines.cc | 313 ++++++++++------------- engine/database/models.cc | 114 ++++----- engine/database/models.h | 10 +- engine/migrations/schema_version.h | 2 +- engine/migrations/v1/migration.h | 164 ++++++++++++ engine/test/components/test_models_db.cc | 26 +- 7 files changed, 381 insertions(+), 256 deletions(-) create mode 100644 engine/migrations/v1/migration.h diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index c1986cf91..d12a4f0a7 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -378,9 +378,9 @@ void Models::ImportModel( auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); cortex::db::ModelEntry model_entry{ + modelHandle, "", "", yaml_rel_path.string(), modelHandle, "local", "imported", cortex::db::ModelStatus::Downloaded, - "", "", "", yaml_rel_path.string(), - modelHandle}; + ""}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -639,9 +639,9 @@ void Models::AddRemoteModel( fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); // TODO: remove hardcode "openai" when engine is finish cortex::db::ModelEntry model_entry{ + model_handle, "", "", yaml_rel_path.string(), model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, - "openai", "", "", yaml_rel_path.string(), - model_handle}; + "openai"}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); if (modellist_utils_obj.AddModelEntry(model_entry).value()) { diff --git a/engine/database/engines.cc b/engine/database/engines.cc index 5fa71bbf2..a4d13ef79 100644 --- a/engine/database/engines.cc +++ b/engine/database/engines.cc @@ -4,203 +4,170 @@ namespace cortex::db { -void CreateTable(SQLite::Database& db) { - db.exec( - "CREATE TABLE IF NOT EXISTS engines (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "engine_name TEXT," - "type TEXT," - "api_key TEXT," - "url TEXT," - "version TEXT," - "variant TEXT," - "status TEXT," - "metadata TEXT," - "date_created TEXT DEFAULT CURRENT_TIMESTAMP," - "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," - "UNIQUE(engine_name, variant));"); -} +void CreateTable(SQLite::Database& db) {} Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { - CreateTable(db_); + CreateTable(db_); } Engines::Engines(SQLite::Database& db) : db_(db) { - CreateTable(db_); + CreateTable(db_); } Engines::~Engines() {} -std::optional Engines::UpsertEngine(const std::string& engine_name, - const std::string& type, - const std::string& api_key, - const std::string& url, - const std::string& version, - const std::string& variant, - const std::string& status, - const std::string& metadata) { - try { - SQLite::Statement query(db_, - "INSERT INTO engines (engine_name, type, api_key, url, version, variant, status, metadata) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?) " - "ON CONFLICT(engine_name, variant) DO UPDATE SET " - "type = excluded.type, " - "api_key = excluded.api_key, " - "url = excluded.url, " - "version = excluded.version, " - "status = excluded.status, " - "metadata = excluded.metadata, " - "date_updated = CURRENT_TIMESTAMP " - "RETURNING id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated;"); - - query.bind(1, engine_name); - query.bind(2, type); - query.bind(3, api_key); - query.bind(4, url); - query.bind(5, version); - query.bind(6, variant); - query.bind(7, status); - query.bind(8, metadata); - - if (query.executeStep()) { - return EngineEntry{ - query.getColumn(0).getInt(), - query.getColumn(1).getString(), - query.getColumn(2).getString(), - query.getColumn(3).getString(), - query.getColumn(4).getString(), - query.getColumn(5).getString(), - query.getColumn(6).getString(), - query.getColumn(7).getString(), - query.getColumn(8).getString(), - query.getColumn(9).getString(), - query.getColumn(10).getString() - }; - } else { - return std::nullopt; - } - } catch (const std::exception& e) { - return std::nullopt; +std::optional Engines::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + try { + SQLite::Statement query( + db_, + "INSERT INTO engines (engine_name, type, api_key, url, version, " + "variant, status, metadata) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?) " + "ON CONFLICT(engine_name, variant) DO UPDATE SET " + "type = excluded.type, " + "api_key = excluded.api_key, " + "url = excluded.url, " + "version = excluded.version, " + "status = excluded.status, " + "metadata = excluded.metadata, " + "date_updated = CURRENT_TIMESTAMP " + "RETURNING id, engine_name, type, api_key, url, version, variant, " + "status, metadata, date_created, date_updated;"); + + query.bind(1, engine_name); + query.bind(2, type); + query.bind(3, api_key); + query.bind(4, url); + query.bind(5, version); + query.bind(6, variant); + query.bind(7, status); + query.bind(8, metadata); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; } + } catch (const std::exception& e) { + return std::nullopt; + } } std::optional> Engines::GetEngines() const { - try { - SQLite::Statement query(db_, - "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " - "FROM engines " - "WHERE status = 'Default' " - "ORDER BY date_updated DESC"); - - std::vector engines; - while (query.executeStep()) { - engines.push_back(EngineEntry{ - query.getColumn(0).getInt(), - query.getColumn(1).getString(), - query.getColumn(2).getString(), - query.getColumn(3).getString(), - query.getColumn(4).getString(), - query.getColumn(5).getString(), - query.getColumn(6).getString(), - query.getColumn(7).getString(), - query.getColumn(8).getString(), - query.getColumn(9).getString(), - query.getColumn(10).getString() - }); - } - - return engines; - } catch (const std::exception& e) { - return std::nullopt; + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE status = 'Default' " + "ORDER BY date_updated DESC"); + + std::vector engines; + while (query.executeStep()) { + engines.push_back(EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}); } + + return engines; + } catch (const std::exception& e) { + return std::nullopt; + } } std::optional Engines::GetEngineById(int id) const { - try { - SQLite::Statement query(db_, - "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " - "FROM engines " - "WHERE id = ? AND status = 'Default' " - "ORDER BY date_updated DESC LIMIT 1"); - - query.bind(1, id); - - if (query.executeStep()) { - return EngineEntry{ - query.getColumn(0).getInt(), - query.getColumn(1).getString(), - query.getColumn(2).getString(), - query.getColumn(3).getString(), - query.getColumn(4).getString(), - query.getColumn(5).getString(), - query.getColumn(6).getString(), - query.getColumn(7).getString(), - query.getColumn(8).getString(), - query.getColumn(9).getString(), - query.getColumn(10).getString() - }; - } else { - return std::nullopt; - } - } catch (const std::exception& e) { - return std::nullopt; + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE id = ? AND status = 'Default' " + "ORDER BY date_updated DESC LIMIT 1"); + + query.bind(1, id); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; } + } catch (const std::exception& e) { + return std::nullopt; + } } -std::optional Engines::GetEngineByNameAndVariant(const std::string& engine_name, const std::optional variant) const { - try { - std::string queryStr = - "SELECT id, engine_name, type, api_key, url, version, variant, status, metadata, date_created, date_updated " - "FROM engines " - "WHERE engine_name = ? AND status = 'Default' "; - - if (variant) { - queryStr += "AND variant = ? "; - } - - queryStr += "ORDER BY date_updated DESC LIMIT 1"; - - SQLite::Statement query(db_, queryStr); - - query.bind(1, engine_name); - - if (variant) { - query.bind(2, variant.value()); - } - - if (query.executeStep()) { - return EngineEntry{ - query.getColumn(0).getInt(), - query.getColumn(1).getString(), - query.getColumn(2).getString(), - query.getColumn(3).getString(), - query.getColumn(4).getString(), - query.getColumn(5).getString(), - query.getColumn(6).getString(), - query.getColumn(7).getString(), - query.getColumn(8).getString(), - query.getColumn(9).getString(), - query.getColumn(10).getString() - }; - } else { - return std::nullopt; - } - } catch (const std::exception& e) { - return std::nullopt; +std::optional Engines::GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant) const { + try { + std::string queryStr = + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE engine_name = ? AND status = 'Default' "; + + if (variant) { + queryStr += "AND variant = ? "; + } + + queryStr += "ORDER BY date_updated DESC LIMIT 1"; + + SQLite::Statement query(db_, queryStr); + + query.bind(1, engine_name); + + if (variant) { + query.bind(2, variant.value()); } + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } } std::optional Engines::DeleteEngineById(int id) { - try { - SQLite::Statement query(db_, - "DELETE FROM engines WHERE id = ?"); - - query.bind(1, id); - query.exec(); - return std::nullopt; - } catch (const std::exception& e) { - return std::string("Failed to delete engine: ") + e.what(); - } + try { + SQLite::Statement query(db_, "DELETE FROM engines WHERE id = ?"); + + query.bind(1, id); + query.exec(); + return std::nullopt; + } catch (const std::exception& e) { + return std::string("Failed to delete engine: ") + e.what(); + } } } // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.cc b/engine/database/models.cc index f09633c1f..860c60f3f 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -7,19 +7,7 @@ namespace cortex::db { -Models::Models() : db_(cortex::db::Database::GetInstance().db()) { - // db_.exec( - // "CREATE TABLE IF NOT EXISTS models (" - // "model_id TEXT PRIMARY KEY," - // "model_format TEXT," - // "model_source TEXT," - // "status TEXT," - // "engine TEXT," - // "author_repo_id TEXT," - // "branch_name TEXT," - // "path_to_model_yaml TEXT," - // "model_alias TEXT);"); -} +Models::Models() : db_(cortex::db::Database::GetInstance().db()) {} Models::~Models() {} @@ -33,7 +21,6 @@ std::string Models::StatusToString(ModelStatus status) const { return "undownloaded"; } return "unknown"; - } Models::Models(SQLite::Database& db) : db_(db) {} @@ -76,21 +63,21 @@ cpp::result, std::string> Models::LoadModelListNoLock() try { std::vector entries; SQLite::Statement query(db_, - "SELECT model_id, model_format, model_source, " - "status, engine, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models"); + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models"); while (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); - entry.model_format = query.getColumn(1).getString(); - entry.model_source = query.getColumn(2).getString(); - entry.status = StringToStatus(query.getColumn(3).getString()); - entry.engine = query.getColumn(4).getString(); - entry.author_repo_id = query.getColumn(5).getString(); - entry.branch_name = query.getColumn(6).getString(); - entry.path_to_model_yaml = query.getColumn(7).getString(); - entry.model_alias = query.getColumn(8).getString(); + entry.author_repo_id = query.getColumn(1).getString(); + entry.branch_name = query.getColumn(2).getString(); + entry.path_to_model_yaml = query.getColumn(3).getString(); + entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); entries.push_back(entry); } return entries; @@ -164,9 +151,9 @@ cpp::result Models::GetModelInfo( const std::string& identifier) const { try { SQLite::Statement query(db_, - "SELECT model_id, model_format, model_source, " - "status, engine, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models " + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models " "WHERE model_id = ? OR model_alias = ?"); query.bind(1, identifier); @@ -174,14 +161,14 @@ cpp::result Models::GetModelInfo( if (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); - entry.model_format = query.getColumn(1).getString(); - entry.model_source = query.getColumn(2).getString(); - entry.status = StringToStatus(query.getColumn(3).getString()); - entry.engine = query.getColumn(4).getString(); - entry.author_repo_id = query.getColumn(5).getString(); - entry.branch_name = query.getColumn(6).getString(); - entry.path_to_model_yaml = query.getColumn(7).getString(); - entry.model_alias = query.getColumn(8).getString(); + entry.author_repo_id = query.getColumn(1).getString(); + entry.branch_name = query.getColumn(2).getString(); + entry.path_to_model_yaml = query.getColumn(3).getString(); + entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -193,14 +180,14 @@ cpp::result Models::GetModelInfo( void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Model ID: " << entry.model; - LOG_INFO << "Model Format: " << entry.model_format; - LOG_INFO << "Model Source: " << entry.model_source; - LOG_INFO << "Status: " << StatusToString(entry.status); - LOG_INFO << "Engine: " << entry.engine; LOG_INFO << "Author/Repo ID: " << entry.author_repo_id; LOG_INFO << "Branch Name: " << entry.branch_name; LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; LOG_INFO << "Model Alias: " << entry.model_alias; + LOG_INFO << "Model Format: " << entry.model_format; + LOG_INFO << "Model Source: " << entry.model_source; + LOG_INFO << "Status: " << StatusToString(entry.status); + LOG_INFO << "Engine: " << entry.engine; } cpp::result Models::AddModelEntry(ModelEntry new_entry, @@ -221,18 +208,18 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, SQLite::Statement insert( db_, - "INSERT INTO models (model_id, model_format, model_source, status, " - "engine, author_repo_id, branch_name, path_to_model_yaml, model_alias) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); + "INSERT INTO models (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, model_source, " + "status, engine) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); - insert.bind(2, new_entry.model_format); - insert.bind(3, new_entry.model_source); - insert.bind(4, StatusToString(new_entry.status)); - insert.bind(5, new_entry.engine); - insert.bind(6, new_entry.author_repo_id); - insert.bind(7, new_entry.branch_name); - insert.bind(8, new_entry.path_to_model_yaml); - insert.bind(9, new_entry.model_alias); + insert.bind(2, new_entry.author_repo_id); + insert.bind(3, new_entry.branch_name); + insert.bind(4, new_entry.path_to_model_yaml); + insert.bind(5, new_entry.model_alias); + insert.bind(6, new_entry.model_format); + insert.bind(7, new_entry.model_source); + insert.bind(8, StatusToString(new_entry.status)); + insert.bind(9, new_entry.engine); insert.exec(); return true; @@ -250,19 +237,18 @@ cpp::result Models::UpdateModelEntry( return cpp::fail("Model not found: " + identifier); } try { - SQLite::Statement upd(db_, - "UPDATE models " - "SET model_format = ?, model_source = ?, status = ?, " - "engine = ?, author_repo_id = ?, branch_name = ?, " - "path_to_model_yaml = ? " - "WHERE model_id = ? OR model_alias = ?"); - upd.bind(1, updated_entry.model_format); - upd.bind(2, updated_entry.model_source); - upd.bind(3, StatusToString(updated_entry.status)); - upd.bind(4, updated_entry.engine); - upd.bind(5, updated_entry.author_repo_id); - upd.bind(6, updated_entry.branch_name); - upd.bind(7, updated_entry.path_to_model_yaml); + SQLite::Statement upd( + db_, + "UPDATE models SET author_repo_id = ?, branch_name = ?, " + "path_to_model_yaml = ?, model_format = ?, model_source = ?, status = " + "?, engine = ? WHERE model_id = ? OR model_alias = ?"); + upd.bind(1, updated_entry.author_repo_id); + upd.bind(2, updated_entry.branch_name); + upd.bind(3, updated_entry.path_to_model_yaml); + upd.bind(4, updated_entry.model_format); + upd.bind(5, updated_entry.model_source); + upd.bind(6, StatusToString(updated_entry.status)); + upd.bind(7, updated_entry.engine); upd.bind(8, identifier); upd.bind(9, identifier); return upd.exec() == 1; diff --git a/engine/database/models.h b/engine/database/models.h index ecab1cd12..dd6e2a5a1 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -15,15 +15,15 @@ enum class ModelStatus { }; struct ModelEntry { - std::string model; - std::string model_format; - std::string model_source; - ModelStatus status; - std::string engine; + std::string model; std::string author_repo_id; std::string branch_name; std::string path_to_model_yaml; std::string model_alias; + std::string model_format; + std::string model_source; + ModelStatus status; + std::string engine; }; class Models { diff --git a/engine/migrations/schema_version.h b/engine/migrations/schema_version.h index 7cfccf27a..1e64110e3 100644 --- a/engine/migrations/schema_version.h +++ b/engine/migrations/schema_version.h @@ -1,4 +1,4 @@ #pragma once //Track the current schema version -#define SCHEMA_VERSION 0 \ No newline at end of file +#define SCHEMA_VERSION 1 \ No newline at end of file diff --git a/engine/migrations/v1/migration.h b/engine/migrations/v1/migration.h new file mode 100644 index 000000000..35472904c --- /dev/null +++ b/engine/migrations/v1/migration.h @@ -0,0 +1,164 @@ +#pragma once +#include +#include +#include +#include "utils/file_manager_utils.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace cortex::migr::v1 { +// Data folder +namespace fmu = file_manager_utils; + +// cortexcpp +// |__ models +// | |__ cortex.so +// | |__ tinyllama +// | |__ gguf +// |__ engines +// | |__ cortex.llamacpp +// | |__ deps +// | |__ windows-amd64-avx +// |__ logs +// +inline cpp::result MigrateFolderStructureUp() { + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "models")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "models"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "engines")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "engines"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "logs")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "logs"); + } + + return true; +} + +inline cpp::result MigrateFolderStructureDown() { + // CTL_INF("Folder structure already up to date!"); + return true; +} + +// Database +inline cpp::result MigrateDBUp(SQLite::Database& db) { + try { + db.exec( + "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY " + "KEY);"); + + // models + { + // Check if the table exists + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + + if (table_exists) { + // Alter existing table + db.exec("ALTER TABLE models ADD COLUMN model_format TEXT"); + db.exec("ALTER TABLE models ADD COLUMN model_source TEXT"); + db.exec("ALTER TABLE models ADD COLUMN status TEXT"); + db.exec("ALTER TABLE models ADD COLUMN engine TEXT"); + } else { + // Create new table + db.exec( + "CREATE TABLE models (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT" + ")"); + } + } + + db.exec( + "CREATE TABLE IF NOT EXISTS hardware (" + "uuid TEXT PRIMARY KEY, " + "type TEXT NOT NULL, " + "hardware_id INTEGER NOT NULL, " + "software_id INTEGER NOT NULL, " + "activated INTEGER NOT NULL CHECK (activated IN (0, 1)));"); + + // engines + db.exec( + "CREATE TABLE IF NOT EXISTS engines (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "engine_name TEXT," + "type TEXT," + "api_key TEXT," + "url TEXT," + "version TEXT," + "variant TEXT," + "status TEXT," + "metadata TEXT," + "date_created TEXT DEFAULT CURRENT_TIMESTAMP," + "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," + "UNIQUE(engine_name, variant));"); + + // CTL_INF("Database migration up completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration up failed: " << e.what()); + return cpp::fail(e.what()); + } +}; + +inline cpp::result MigrateDBDown(SQLite::Database& db) { + try { + // models + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + if (table_exists) { + // Create a new table with the old schema + db.exec( + "CREATE TABLE models_old (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT" + ")"); + + // Copy data from the current table to the new table + db.exec( + "INSERT INTO models_old (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias) " + "SELECT model_id, author_repo_id, branch_name, path_to_model_yaml, " + "model_alias FROM models"); + + // Drop the current table + db.exec("DROP TABLE models"); + + // Rename the new table to the original name + db.exec("ALTER TABLE models_old RENAME TO models"); + } + } + + // hardware + { + // Do nothing + } + + // engines + db.exec("DROP TABLE IF EXISTS engines;"); + // CTL_INF("Migration down completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration down failed: " << e.what()); + return cpp::fail(e.what()); + } +} + +}; // namespace cortex::migr::v1 diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index 9ea9fbc72..ab0ea9f70 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -15,12 +15,17 @@ class ModelsTestSuite : public ::testing::Test { void SetUp() { try { db_.exec( - "CREATE TABLE IF NOT EXISTS models (" + "CREATE TABLE models (" "model_id TEXT PRIMARY KEY," "author_repo_id TEXT," "branch_name TEXT," "path_to_model_yaml TEXT," - "model_alias TEXT);"); + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT" + ")"); } catch (const std::exception& e) {} } @@ -35,15 +40,18 @@ class ModelsTestSuite : public ::testing::Test { cortex::db::Models model_list_; const cortex::db::ModelEntry kTestModel{ - "test_model_id", "test_format", "test_source", cortex::db::ModelStatus::Downloaded, "test_engine", - "test_author", "main", "/path/to/model.yaml", "test_alias"}; + "test_model_id", "test_author", + "main", "/path/to/model.yaml", + "test_alias", "test_format", + "test_source", cortex::db::ModelStatus::Downloaded, + "test_engine"}; }; TEST_F(ModelsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); EXPECT_EQ(retrieved_model.value().model_format, kTestModel.model_format); @@ -59,7 +67,7 @@ TEST_F(ModelsTestSuite, TestGetModelInfo) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto model_by_id = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(model_by_id); + EXPECT_TRUE(model_by_id.has_value()); EXPECT_EQ(model_by_id.value().model, kTestModel.model); auto model_by_alias = model_list_.GetModelInfo("test_alias"); @@ -82,7 +90,7 @@ TEST_F(ModelsTestSuite, TestUpdateModelEntry) { model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().status, updated_model.status); // Clean up @@ -122,7 +130,7 @@ TEST_F(ModelsTestSuite, TestPersistence) { // Create a new ModelListUtils instance to test if it loads from file cortex::db::Models new_model_list(db_); auto retrieved_model = new_model_list.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -141,7 +149,7 @@ TEST_F(ModelsTestSuite, TestUpdateModelAlias) { EXPECT_TRUE( model_list_.UpdateModelAlias(kTestModel.model, kNewTestAlias).value()); auto updated_model = model_list_.GetModelInfo(kNewTestAlias); - EXPECT_TRUE(updated_model); + EXPECT_TRUE(updated_model.has_value()); EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias); EXPECT_EQ(updated_model.value().model, kTestModel.model); From e14ee6e032d9a86cd645b8112271afdbef058259 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 29 Nov 2024 14:22:59 +0700 Subject: [PATCH 28/33] fix: issue with db --- engine/migrations/db_helper.h | 26 ++++++++++++++++++++++++++ engine/migrations/migration_helper.cc | 2 +- engine/migrations/migration_manager.cc | 25 ++++++++++++++++++++++++- engine/migrations/v1/migration.h | 11 ++++++----- 4 files changed, 57 insertions(+), 7 deletions(-) create mode 100644 engine/migrations/db_helper.h diff --git a/engine/migrations/db_helper.h b/engine/migrations/db_helper.h new file mode 100644 index 000000000..0990426bf --- /dev/null +++ b/engine/migrations/db_helper.h @@ -0,0 +1,26 @@ +#pragma once +#include + +namespace cortex::mgr { +#include +#include +#include +#include + +inline bool ColumnExists(SQLite::Database& db, const std::string& table_name, const std::string& column_name) { + try { + SQLite::Statement query(db, "SELECT " + column_name + " FROM " + table_name + " LIMIT 0"); + return true; + } catch (std::exception&) { + return false; + } +} + +inline void AddColumnIfNotExists(SQLite::Database& db, const std::string& table_name, + const std::string& column_name, const std::string& column_type) { + if (!ColumnExists(db, table_name, column_name)) { + std::string sql = "ALTER TABLE " + table_name + " ADD COLUMN " + column_name + " " + column_type; + db.exec(sql); + } +} +} \ No newline at end of file diff --git a/engine/migrations/migration_helper.cc b/engine/migrations/migration_helper.cc index f2b39d77e..e87cd8f00 100644 --- a/engine/migrations/migration_helper.cc +++ b/engine/migrations/migration_helper.cc @@ -8,7 +8,7 @@ cpp::result MigrationHelper::BackupDatabase( SQLite::Database src_db(src_db_path, SQLite::OPEN_READONLY); sqlite3* backup_db; - if (sqlite3_open16(backup_db_path.c_str(), &backup_db) != SQLITE_OK) { + if (sqlite3_open(backup_db_path.c_str(), &backup_db) != SQLITE_OK) { throw std::runtime_error("Failed to open backup database"); } diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc index f4b4f8046..1e69a3118 100644 --- a/engine/migrations/migration_manager.cc +++ b/engine/migrations/migration_manager.cc @@ -4,6 +4,9 @@ #include "schema_version.h" #include "utils/file_manager_utils.h" #include "utils/scope_exit.h" +#include "utils/widechar_conv.h" +#include "v0/migration.h" +#include "v1/migration.h" namespace cortex::migr { @@ -40,7 +43,15 @@ cpp::result MigrationManager::Migrate() { if (std::filesystem::exists(fmu::GetCortexDataPath() / kCortexDb)) { auto src_db_path = (fmu::GetCortexDataPath() / kCortexDb); auto backup_db_path = (fmu::GetCortexDataPath() / kCortexDbBackup); - if (auto res = mgr_helper_.BackupDatabase(src_db_path, backup_db_path.string()); + std::cout << src_db_path.string() << std::endl; + std::cout << backup_db_path.string() << std::endl; +#if defined(_WIN32) + if (auto res = mgr_helper_.BackupDatabase( + src_db_path, cortex::wc::WstringToUtf8(backup_db_path.wstring())); +#else + if (auto res = + mgr_helper_.BackupDatabase(src_db_path, backup_db_path.string()); +#endif res.has_error()) { CTL_INF("Error: backup database failed!"); return res; @@ -133,6 +144,9 @@ cpp::result MigrationManager::DoUpFolderStructure( case 0: return v0::MigrateFolderStructureUp(); break; + case 1: + return v1::MigrateFolderStructureUp(); + break; default: return true; @@ -144,6 +158,9 @@ cpp::result MigrationManager::DoDownFolderStructure( case 0: return v0::MigrateFolderStructureDown(); break; + case 1: + return v1::MigrateFolderStructureDown(); + break; default: return true; @@ -177,6 +194,9 @@ cpp::result MigrationManager::DoUpDB(int version) { case 0: return v0::MigrateDBUp(db_); break; + case 1: + return v1::MigrateDBUp(db_); + break; default: return true; @@ -188,6 +208,9 @@ cpp::result MigrationManager::DoDownDB(int version) { case 0: return v0::MigrateDBDown(db_); break; + case 1: + return v1::MigrateDBDown(db_); + break; default: return true; diff --git a/engine/migrations/v1/migration.h b/engine/migrations/v1/migration.h index 35472904c..f9a8038e3 100644 --- a/engine/migrations/v1/migration.h +++ b/engine/migrations/v1/migration.h @@ -2,6 +2,7 @@ #include #include #include +#include "migrations/db_helper.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" #include "utils/result.hpp" @@ -59,10 +60,10 @@ inline cpp::result MigrateDBUp(SQLite::Database& db) { if (table_exists) { // Alter existing table - db.exec("ALTER TABLE models ADD COLUMN model_format TEXT"); - db.exec("ALTER TABLE models ADD COLUMN model_source TEXT"); - db.exec("ALTER TABLE models ADD COLUMN status TEXT"); - db.exec("ALTER TABLE models ADD COLUMN engine TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "model_format", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "model_source", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "status", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "engine", "TEXT"); } else { // Create new table db.exec( @@ -152,7 +153,7 @@ inline cpp::result MigrateDBDown(SQLite::Database& db) { } // engines - db.exec("DROP TABLE IF EXISTS engines;"); + db.exec("DROP TABLE IF EXISTS engines;"); // CTL_INF("Migration down completed successfully."); return true; } catch (const std::exception& e) { From cee2838cecd184c1d6bd5737066bccbcb365102b Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Sat, 30 Nov 2024 04:49:02 +0700 Subject: [PATCH 29/33] chore: refactor remote engine --- engine/CMakeLists.txt | 2 + engine/cli/CMakeLists.txt | 2 + engine/controllers/models.cc | 28 +++++++++ engine/controllers/models.h | 6 ++ engine/cortex-common/EngineI.h | 2 + .../remote-engine/anthropic_engine.cc | 62 +++++++++++++++++++ .../remote-engine/anthropic_engine.h | 13 ++++ .../extensions/remote-engine/openai_engine.cc | 54 ++++++++++++++++ .../extensions/remote-engine/openai_engine.h | 14 +++++ .../extensions/remote-engine/remote_engine.cc | 45 +++----------- .../extensions/remote-engine/remote_engine.h | 12 +++- engine/services/engine_service.cc | 30 +++++++-- engine/services/engine_service.h | 3 + engine/services/model_service.cc | 3 +- engine/utils/engine_constants.h | 2 + 15 files changed, 232 insertions(+), 46 deletions(-) create mode 100644 engine/extensions/remote-engine/anthropic_engine.cc create mode 100644 engine/extensions/remote-engine/anthropic_engine.h create mode 100644 engine/extensions/remote-engine/openai_engine.cc create mode 100644 engine/extensions/remote-engine/openai_engine.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 6b2ede2fd..3ad2590fe 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -144,6 +144,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index ae2f16f1d..45ee4143a 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -84,6 +84,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index d12a4f0a7..1f51e2368 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -594,6 +594,34 @@ void Models::GetModelStatus( } } +void Models::GetRemoteModels( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& engine_id) { + if (!remote_engine::IsRemoteEngine(engine_id)) { + Json::Value ret; + ret["message"] = "Not a remote engine: " + engine_id; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->GetRemoteModels(engine_id); + + if (result.has_error()) { + Json::Value ret; + ret["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + } else { + auto resp = cortex_utils::CreateCortexHttpJsonResponse(result.value()); + resp->setStatusCode(k200OK); + callback(resp); + } +} + void Models::AddRemoteModel( const HttpRequestPtr& req, std::function&& callback) const { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 24adffd1c..3227c0999 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -22,6 +22,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::StopModel, "/stop", Options, Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); + METHOD_ADD(Models::GetRemoteModels, "/remote/{1}", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -34,6 +35,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); + ADD_METHOD_TO(Models::GetRemoteModels, "/v1/remote/{1}", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, @@ -78,6 +80,10 @@ class Models : public drogon::HttpController { std::function&& callback, const std::string& model_id); + void GetRemoteModels(const HttpRequestPtr& req, + std::function&& callback, + const std::string& engine_id); + private: std::shared_ptr model_service_; std::shared_ptr engine_service_; diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index 95ce605de..51e19c124 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -37,4 +37,6 @@ class EngineI { virtual bool SetFileLogger(int max_log_lines, const std::string& log_path) = 0; virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0; + + virtual Json::Value GetRemoteModels() = 0; }; diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc new file mode 100644 index 000000000..847cba566 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.cc @@ -0,0 +1,62 @@ +#include "anthropic_engine.h" +#include +#include +#include "utils/logging_utils.h" + +namespace remote_engine { +namespace { +constexpr const std::array kAnthropicModels = { + "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"}; +} +void AnthropicEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "anthropic"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value AnthropicEngine::GetRemoteModels() { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + for (const auto& m : kAnthropicModels) { + Json::Value val; + val["id"] = std::string(m); + val["engine"] = "anthropic"; + val["created"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + CTL_INF("Remote models responded"); + return json_resp; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h new file mode 100644 index 000000000..bcd3dfaf7 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.h @@ -0,0 +1,13 @@ +#pragma once +#include "remote_engine.h" + +namespace remote_engine { + class AnthropicEngine: public RemoteEngine { +public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; + }; +} \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc new file mode 100644 index 000000000..7c7d70385 --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.cc @@ -0,0 +1,54 @@ +#include "openai_engine.h" +#include "utils/logging_utils.h" + +namespace remote_engine { + +void OpenAiEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "openai"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value OpenAiEngine::GetRemoteModels() { + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value error; + error["error"] = response.error_message; + return error; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value error; + error["error"] = "Failed to parse response"; + return error; + } + return response_json; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h new file mode 100644 index 000000000..61dc68f0c --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.h @@ -0,0 +1,14 @@ +#pragma once + +#include "remote_engine.h" + +namespace remote_engine { +class OpenAiEngine : public RemoteEngine { + public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; +}; +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 894eeb009..d9aea2f41 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -242,7 +242,7 @@ RemoteEngine::~RemoteEngine() { RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( const std::string& model) { - std::shared_lock lock(models_mutex_); + std::shared_lock lock(models_mtx_); auto it = models_.find(model); if (it != models_.end()) { return &it->second; @@ -382,7 +382,7 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, // Thread-safe update of models map { - std::unique_lock lock(models_mutex_); + std::unique_lock lock(models_mtx_); models_[model] = std::move(model_config); } CTL_DBG("LoadModelConfig successfully: " << model << ", " << yaml_path); @@ -398,39 +398,7 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, void RemoteEngine::GetModels( std::shared_ptr json_body, std::function&& callback) { - - auto response = MakeGetModelsRequest(); - if (response.error) { - Json::Value status; - status["is_done"] = true; - status["has_error"] = true; - status["is_stream"] = false; - status["status_code"] = k400BadRequest; - Json::Value error; - error["error"] = response.error_message; - callback(std::move(status), std::move(error)); - return; - } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value status; - status["is_done"] = true; - status["has_error"] = true; - status["is_stream"] = false; - status["status_code"] = k500InternalServerError; - Json::Value error; - error["error"] = "Failed to parse response"; - callback(std::move(status), std::move(error)); - return; - } - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = k200OK; - - callback(std::move(status), std::move(response_json)); + CTL_WRN("Not implemented yet!"); } void RemoteEngine::LoadModel( @@ -497,7 +465,7 @@ void RemoteEngine::UnloadModel( const std::string& model = (*json_body)["model"].asString(); { - std::unique_lock lock(models_mutex_); + std::unique_lock lock(models_mtx_); models_.erase(model); } @@ -775,6 +743,11 @@ void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { trantor::Logger::setLogLevel(log_level); } +Json::Value RemoteEngine::GetRemoteModels() { + CTL_WRN("Not implemented yet!"); + return {}; +} + extern "C" { EngineI* get_engine() { return new RemoteEngine(); diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 80c703f8f..153ec6408 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -9,10 +9,15 @@ #include #include "cortex-common/EngineI.h" #include "extensions/remote-engine/template_renderer.h" +#include "utils/engine_constants.h" #include "utils/file_logger.h" // Helper for CURL response namespace remote_engine { +inline bool IsRemoteEngine(std::string_view e) { + return e == kAnthropicEngine || e == kOpenAiEngine; +} + struct StreamContext { std::shared_ptr> callback; std::string buffer; @@ -27,7 +32,7 @@ struct CurlResponse { }; class RemoteEngine : public EngineI { - private: + protected: // Model configuration struct ModelConfig { std::string model; @@ -40,7 +45,7 @@ class RemoteEngine : public EngineI { }; // Thread-safe model config storage - mutable std::shared_mutex models_mutex_; + mutable std::shared_mutex models_mtx_; std::unordered_map models_; TemplateRenderer renderer_; Json::Value metadata_; @@ -63,7 +68,7 @@ class RemoteEngine : public EngineI { public: RemoteEngine(); - ~RemoteEngine(); + virtual ~RemoteEngine(); // Main interface implementations void GetModels( @@ -93,6 +98,7 @@ class RemoteEngine : public EngineI { bool IsSupported(const std::string& feature) override; bool SetFileLogger(int max_log_lines, const std::string& log_path) override; void SetLogLevel(trantor::Logger::LogLevel logLevel) override; + Json::Value GetRemoteModels() override; }; } // namespace remote_engine \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 0f835d4de..d521c480d 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -5,7 +5,8 @@ #include #include "algorithm" #include "database/engines.h" -#include "extensions/remote-engine/remote_engine.h" +#include "extensions/remote-engine/anthropic_engine.h" +#include "extensions/remote-engine/openai_engine.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" @@ -680,14 +681,18 @@ cpp::result EngineService::LoadEngine( } // Check for remote engine - if (engine_name != kLlamaEngine && engine_name != kOnnxEngine && - engine_name != kTrtLlmEngine) { + if (remote_engine::IsRemoteEngine(engine_name)) { auto exist_engine = GetEngineByNameAndVariant(engine_name); if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - engines_[engine_name].engine = new remote_engine::RemoteEngine(); + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + auto& en = std::get(engines_[ne].engine); auto config = file_manager_utils::GetCortexConfig(); if (en->IsSupported("SetFileLogger")) { @@ -904,7 +909,7 @@ cpp::result EngineService::IsEngineReady( auto ne = NormalizeEngine(engine); // Check for remote engine - if (engine != kLlamaRepo && engine != kTrtLlmRepo && engine != kOnnxRepo) { + if (remote_engine::IsRemoteEngine(engine)) { auto exist_engine = GetEngineByNameAndVariant(engine); if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine + "' is not installed"); @@ -1066,4 +1071,19 @@ std::string EngineService::DeleteEngine(int id) { } else { return ""; } +} + +cpp::result EngineService::GetRemoteModels( + const std::string& engine_name) { + if (auto r = IsEngineReady(engine_name); r.has_error()) { + return cpp::fail(r.error()); + } + + auto& e = std::get(engines_[engine_name].engine); + auto res = e->GetRemoteModels(); + if (!res["error"].isNull()) { + return cpp::fail(res["error"].asString()); + } else { + return res; + } } \ No newline at end of file diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 833e4b861..e208fb4ed 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -144,6 +144,9 @@ class EngineService : public EngineServiceI { std::string DeleteEngine(int id); + cpp::result GetRemoteModels( + const std::string& engine_name); + private: cpp::result DownloadEngine( const std::string& engine, const std::string& version = "latest", diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index cfc9f3689..903757d09 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -668,8 +668,7 @@ cpp::result ModelService::StartModel( auto mc = yaml_handler.GetModelConfig(); // Running remote model - if (mc.engine != kLlamaEngine && mc.engine != kOnnxEngine && - mc.engine != kTrtLlmEngine) { + if (remote_engine::IsRemoteEngine(mc.engine)) { config::RemoteModelConfig remote_mc; remote_mc.LoadFromYamlFile( diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 5dab49936..020109fd8 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -3,6 +3,8 @@ constexpr const auto kOnnxEngine = "onnxruntime"; constexpr const auto kLlamaEngine = "llama-cpp"; constexpr const auto kTrtLlmEngine = "tensorrt-llm"; +constexpr const auto kOpenAiEngine = "openai"; +constexpr const auto kAnthropicEngine = "anthropic"; constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; From 1f2a5dc13bf99ae9a6c0f68d7eb6d51499be0043 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 2 Dec 2024 08:31:28 +0700 Subject: [PATCH 30/33] fix: e2e tests --- engine/controllers/engines.cc | 5 ++++- engine/e2e-test/test_cli_model_pull_direct_url.py | 1 + engine/services/model_service.cc | 12 +++++++----- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index c46d84760..3d3c0c037 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -174,7 +174,8 @@ void Engines::InstallEngine( norm_version = version; } - if ((*(req->getJsonObject())).get("type", "").asString() == "remote") { + if ((req->getJsonObject()) && + (*(req->getJsonObject())).get("type", "").asString() == "remote") { auto type = (*(req->getJsonObject())).get("type", "").asString(); auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); auto url = (*(req->getJsonObject())).get("url", "").asString(); @@ -260,12 +261,14 @@ void Engines::InstallEngine( res["message"] = result.error(); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k400BadRequest); + CTL_INF("Error: " << result.error()); callback(resp); } else { Json::Value res; res["message"] = "Engine starts installing!"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k200OK); + CTL_INF("Engine starts installing!"); callback(resp); } } diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py index b10d1593d..baa8fa87f 100644 --- a/engine/e2e-test/test_cli_model_pull_direct_url.py +++ b/engine/e2e-test/test_cli_model_pull_direct_url.py @@ -1,4 +1,5 @@ from test_runner import run +from test_runner import start_server, stop_server import os from pathlib import Path diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index cf4a580c2..cd2a0091b 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -64,11 +64,13 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, auto author_id = author.has_value() ? author.value() : "cortexso"; cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = ggufDownloadItem.id, - .author_repo_id = author_id, - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = ggufDownloadItem.id}; + cortex::db::ModelEntry model_entry{ + .model = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = ggufDownloadItem.id, + .status = cortex::db::ModelStatus::Downloaded}; auto result = modellist_utils_obj.AddModelEntry(model_entry, true); if (result.has_error()) { CTL_WRN("Error adding model to modellist: " + result.error()); From 3220ad8b78cc635930e8a0e836183a8d0685edd1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 2 Dec 2024 08:53:42 +0700 Subject: [PATCH 31/33] fix: e2e tests --- engine/database/models.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/database/models.cc b/engine/database/models.cc index 860c60f3f..fb2128396 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -28,7 +28,7 @@ Models::Models(SQLite::Database& db) : db_(db) {} ModelStatus Models::StringToStatus(const std::string& status_str) const { if (status_str == "remote") { return ModelStatus::Remote; - } else if (status_str == "downloaded") { + } else if (status_str == "downloaded" || status_str.empty()) { return ModelStatus::Downloaded; } else if (status_str == "undownloaded") { return ModelStatus::Undownloaded; From 90694c406359b34fe26e474d06a5001ab0c858f2 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 4 Dec 2024 08:58:32 +0700 Subject: [PATCH 32/33] chore: API docs --- docs/static/openapi/cortex.json | 232 +++++++++++++++++++++++++++++++- 1 file changed, 229 insertions(+), 3 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 78430294f..f6f7b7145 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -512,6 +512,73 @@ } } }, + "/v1/models/add": { + "post": { + "operationId": "ModelsController_addModel", + "summary": "Add a model", + "description": "Add a new model configuration to the system.", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddModelRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "model": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "engine": { + "type": "string" + }, + "version": { + "type": "string" + } + } + } + } + }, + "example": { + "message": "Model added successfully!", + "model": { + "model": "claude-3-5-sonnet-20241022", + "engine": "anthropic", + "version": "2023-06-01" + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SimpleErrorResponse" + } + } + } + } + }, + "tags": ["Pulling Models"] + } + }, "/v1/models": { "get": { "operationId": "ModelsController_findAll", @@ -1417,7 +1484,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1439,6 +1506,31 @@ "type": "string", "description": "The variant of the engine to install (optional)", "example": "mac-arm64" + }, + "type": { + "type": "string", + "description": "The type of connection", + "example": "remote" + }, + "url": { + "type": "string", + "description": "The URL for the API endpoint", + "example": "https://api.openai.com" + }, + "api_key": { + "type": "string", + "description": "The API key for authentication", + "example": "" + }, + "metadata": { + "type": "object", + "properties": { + "get_models_url": { + "type": "string", + "description": "The URL to get models", + "example": "https://api.openai.com/v1/models" + } + } } } } @@ -1475,7 +1567,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1690,7 +1782,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The name of the engine to update" @@ -3623,6 +3715,109 @@ } } }, + "AddModelRequest": { + "type": "object", + "required": ["model", "engine", "version", "inference_params", "TransformReq", "TransformResp", "metadata"], + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model." + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "engine": { + "type": "string", + "description": "The engine used for the model." + }, + "version": { + "type": "string", + "description": "The version of the model." + }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } + }, + "TransformReq": { + "type": "object", + "properties": { + "get_models": { + "type": "object" + }, + "chat_completions": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "TransformResp": { + "type": "object", + "properties": { + "chat_completions": { + "type": "object", + "properties": { + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "properties": { + "author": { + "type": "string" + }, + "description": { + "type": "string" + }, + "end_point": { + "type": "string" + }, + "logo": { + "type": "string" + }, + "api_key_url": { + "type": "string" + } + } + } + } + }, "CreateModelDto": { "type": "object", "properties": { @@ -4292,6 +4487,37 @@ "type": "integer", "description": "Number of GPU layers.", "example": 33 + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "version": { + "type": "string", + "description": "The version of the model." + }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } } } }, From a7e4659a8ce2e1a59b0f27d0ea1b6030ca4c1a57 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 4 Dec 2024 15:35:16 +0700 Subject: [PATCH 33/33] fix: use different interface for remote engine --- docs/static/openapi/cortex.json | 10 +- engine/controllers/models.h | 2 +- engine/cortex-common/remote_enginei.h | 37 ++++++++ .../extensions/remote-engine/remote_engine.cc | 44 --------- .../extensions/remote-engine/remote_engine.h | 8 +- engine/services/engine_service.cc | 39 ++++---- engine/services/engine_service.h | 3 +- engine/services/inference_service.cc | 92 ++++++++++++++----- 8 files changed, 140 insertions(+), 95 deletions(-) create mode 100644 engine/cortex-common/remote_enginei.h diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index f6f7b7145..96ce082e1 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -515,8 +515,8 @@ "/v1/models/add": { "post": { "operationId": "ModelsController_addModel", - "summary": "Add a model", - "description": "Add a new model configuration to the system.", + "summary": "Add a remote model", + "description": "Add a new remote model configuration to the system.", "requestBody": { "required": true, "content": { @@ -1509,17 +1509,17 @@ }, "type": { "type": "string", - "description": "The type of connection", + "description": "The type of connection, remote or local", "example": "remote" }, "url": { "type": "string", - "description": "The URL for the API endpoint", + "description": "The URL for the API endpoint for remote engine", "example": "https://api.openai.com" }, "api_key": { "type": "string", - "description": "The API key for authentication", + "description": "The API key for authentication for remote engine", "example": "" }, "metadata": { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 3227c0999..b2b288adc 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -35,7 +35,7 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); - ADD_METHOD_TO(Models::GetRemoteModels, "/v1/remote/{1}", Get); + ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, diff --git a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h new file mode 100644 index 000000000..81ffbf5cd --- /dev/null +++ b/engine/cortex-common/remote_enginei.h @@ -0,0 +1,37 @@ +#pragma once + +#pragma once + +#include +#include + +#include "json/value.h" +#include "trantor/utils/Logger.h" +class RemoteEngineI { + public: + virtual ~RemoteEngineI() {} + + virtual void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void LoadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) = 0; + + // Get list of running models + virtual void GetModels( + std::shared_ptr jsonBody, + std::function&& callback) = 0; + + // Get available remote models + virtual Json::Value GetRemoteModels() = 0; +}; diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index d9aea2f41..04effb457 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -263,9 +263,7 @@ CurlResponse RemoteEngine::MakeGetModelsRequest() { std::string full_url = metadata_["get_models_url"].asString(); struct curl_slist* headers = nullptr; - headers = curl_slist_append(headers, api_key_template_.c_str()); - headers = curl_slist_append(headers, "Content-Type: application/json"); curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); @@ -304,7 +302,6 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( struct curl_slist* headers = nullptr; if (!config.api_key.empty()) { - headers = curl_slist_append(headers, api_key_template_.c_str()); } @@ -707,50 +704,9 @@ void RemoteEngine::HandleEmbedding( callback(Json::Value(), Json::Value()); } -bool RemoteEngine::IsSupported(const std::string& f) { - if (f == "HandleChatCompletion" || f == "LoadModel" || f == "UnloadModel" || - f == "GetModelStatus" || f == "GetModels" || f == "SetFileLogger" || - f == "SetLogLevel") { - return true; - } - return false; -} - -bool RemoteEngine::SetFileLogger(int max_log_lines, - const std::string& log_path) { - if (!async_file_logger_) { - async_file_logger_ = std::make_unique(); - } - - async_file_logger_->setFileName(log_path); - async_file_logger_->setMaxLines(max_log_lines); // Keep last 100000 lines - async_file_logger_->startLogging(); - trantor::Logger::setOutputFunction( - [&](const char* msg, const uint64_t len) { - if (async_file_logger_) - async_file_logger_->output_(msg, len); - }, - [&]() { - if (async_file_logger_) - async_file_logger_->flush(); - }); - freopen(log_path.c_str(), "w", stderr); - freopen(log_path.c_str(), "w", stdout); - return true; -} - -void RemoteEngine::SetLogLevel(trantor::Logger::LogLevel log_level) { - trantor::Logger::setLogLevel(log_level); -} - Json::Value RemoteEngine::GetRemoteModels() { CTL_WRN("Not implemented yet!"); return {}; } -extern "C" { -EngineI* get_engine() { - return new RemoteEngine(); -} -} } // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 153ec6408..8ce6fa652 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -7,7 +7,7 @@ #include #include #include -#include "cortex-common/EngineI.h" +#include "cortex-common/remote_enginei.h" #include "extensions/remote-engine/template_renderer.h" #include "utils/engine_constants.h" #include "utils/file_logger.h" @@ -31,7 +31,7 @@ struct CurlResponse { std::string error_message; }; -class RemoteEngine : public EngineI { +class RemoteEngine : public RemoteEngineI { protected: // Model configuration struct ModelConfig { @@ -95,9 +95,7 @@ class RemoteEngine : public EngineI { void HandleEmbedding( std::shared_ptr json_body, std::function&& callback) override; - bool IsSupported(const std::string& feature) override; - bool SetFileLogger(int max_log_lines, const std::string& log_path) override; - void SetLogLevel(trantor::Logger::LogLevel logLevel) override; + Json::Value GetRemoteModels() override; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4634a0254..c91fd0dd0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -694,21 +694,6 @@ cpp::result EngineService::LoadEngine( engines_[engine_name].engine = new remote_engine::AnthropicEngine(); } - auto& en = std::get(engines_[ne].engine); - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - CTL_WRN("Method SetFileLogger is not supported yet"); - } - if (en->IsSupported("SetLogLevel")) { - en->SetLogLevel(trantor::Logger::logLevel()); - } else { - CTL_WRN("Method SetLogLevel is not supported yet"); - } CTL_INF("Loaded engine: " << engine_name); return {}; } @@ -883,8 +868,11 @@ cpp::result EngineService::UnloadEngine( if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); } - EngineI* e = std::get(engines_[ne].engine); - delete e; + if (std::holds_alternative(engines_[ne].engine)) { + delete std::get(engines_[ne].engine); + } else { + delete std::get(engines_[ne].engine); + } #if defined(_WIN32) if (!RemoveDllDirectory(engines_[ne].cookie)) { @@ -1100,7 +1088,22 @@ cpp::result EngineService::GetRemoteModels( return cpp::fail(r.error()); } - auto& e = std::get(engines_[engine_name].engine); + if (!IsEngineLoaded(engine_name)) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + + CTL_INF("Loaded engine: " << engine_name); + } + + auto& e = std::get(engines_[engine_name].engine); auto res = e->GetRemoteModels(); if (!res["error"].isNull()) { return cpp::fail(res["error"].asString()); diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 692f7d5f5..8c8bfbbe6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -11,6 +11,7 @@ #include "common/engine_servicei.h" #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" +#include "cortex-common/remote_enginei.h" #include "database/engines.h" #include "extensions/remote-engine/remote_engine.h" #include "services/download_service.h" @@ -37,7 +38,7 @@ struct EngineUpdateResult { } }; -using EngineV = std::variant; +using EngineV = std::variant; class EngineService : public EngineServiceI { private: diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 46309823d..ace7e675f 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -24,14 +24,26 @@ cpp::result InferenceService::HandleChatCompletion( return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleChatCompletion( - json_body, [q, tool_choice](Json::Value status, Json::Value res) { - if (!tool_choice.isNull()) { - res["tool_choice"] = tool_choice; - } - q->push(std::make_pair(status, res)); - }); + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } + return {}; } @@ -53,10 +65,18 @@ cpp::result InferenceService::HandleEmbedding( LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } return {}; } @@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel( // might need mutex here auto engine_result = engine_service_->GetLoadedEngine(engine_type); - auto engine = std::get(engine_result.value()); - engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } return std::make_pair(stt, r); } @@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name, json_body["model"] = model_id; LOG_TRACE << "Start unload model"; - auto engine = std::get(engine_result.value()); - engine->UnloadModel(std::make_shared(json_body), + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), + [&r, &stt](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), [&r, &stt](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); } @@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus( } LOG_TRACE << "Start to get model status"; - auto engine = std::get(engine_result.value()); - engine->GetModelStatus(json_body, + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->GetModelStatus(json_body, + [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->GetModelStatus(json_body, [&stt, &r](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); }