Skip to content

Commit 5db24f8

Browse files
Support v3/models endpoint compatible with OpenAI (#3447)
1 parent 216e5fc commit 5db24f8

File tree

6 files changed

+323
-6
lines changed

6 files changed

+323
-6
lines changed

src/BUILD

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,6 +2464,7 @@ cc_test(
24642464
+ select({
24652465
"//:not_disable_mediapipe": [
24662466
"test/embeddingsnode_test.cpp",
2467+
"test/listmodelsendpoint_test.cpp",
24672468
"test/mediapipeflow_test.cpp",
24682469
"test/mediapipe/inputsidepacketusertestcalc.cc",
24692470
"test/reranknode_test.cpp",
@@ -2648,7 +2649,7 @@ cc_test(
26482649
"//src:custom_nodes_common_buffersqueue",
26492650
"@com_google_googletest//:gtest",
26502651
":pull_hf_model_test",
2651-
":listmodels_test",
2652+
":listdirectorymodels_test",
26522653
":graph_export_test",
26532654
":config_export_test",
26542655
] + select({
@@ -2773,8 +2774,8 @@ cc_library(
27732774
copts = COPTS_TESTS,
27742775
)
27752776
cc_library(
2776-
name = "listmodels_test",
2777-
srcs = ["test/listmodels_test.cpp"],
2777+
name = "listdirectorymodels_test",
2778+
srcs = ["test/listdirectorymodels_test.cpp"],
27782779
alwayslink = True,
27792780
linkopts = [],
27802781
deps = [
@@ -2787,6 +2788,7 @@ cc_library(
27872788
local_defines = COMMON_LOCAL_DEFINES,
27882789
copts = COPTS_TESTS,
27892790
)
2791+
27902792
cc_library(
27912793
name = "graph_export_test",
27922794
linkstatic = 1,

src/http_rest_api_handler.cpp

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <unordered_map>
2828
#include <utility>
2929
#include <vector>
30+
#include <ctime>
3031

3132
#ifndef _WIN32
3233
#include <curl/curl.h>
@@ -113,6 +114,10 @@ const std::string HttpRestApiHandler::kfs_serverliveRegexExp =
113114
const std::string HttpRestApiHandler::kfs_servermetadataRegexExp =
114115
R"(/v2)";
115116

117+
const std::string HttpRestApiHandler::v3_ListModelsRegexExp =
118+
R"(/v3/(v1/)?models)";
119+
const std::string HttpRestApiHandler::v3_RetrieveModelRegexExp =
120+
R"(/v3/(v1/)?models/(.+))";
116121
const std::string HttpRestApiHandler::v3_RegexExp =
117122
R"(/v3/.*?(/|$))";
118123

@@ -129,6 +134,8 @@ HttpRestApiHandler::HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_
129134
kfs_serverreadyRegex(kfs_serverreadyRegexExp),
130135
kfs_serverliveRegex(kfs_serverliveRegexExp),
131136
kfs_servermetadataRegex(kfs_servermetadataRegexExp),
137+
v3_ListModelsRegex(v3_ListModelsRegexExp),
138+
v3_RetrieveModelRegex(v3_RetrieveModelRegexExp),
132139
v3_Regex(v3_RegexExp),
133140
metricsRegex(metricsRegexExp),
134141
timeout_in_ms(timeout_in_ms),
@@ -206,6 +213,12 @@ void HttpRestApiHandler::registerAll() {
206213
return processServerMetadataKFSRequest(request_components, response, request_body);
207214
});
208215

216+
registerHandler(V3_ListModels, [this](const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, HttpResponseComponents& response_components, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) {
217+
return processListModelsRequest(response);
218+
});
219+
registerHandler(V3_RetrieveModel, [this](const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, HttpResponseComponents& response_components, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) {
220+
return processRetrieveModelRequest(request_components.model_name, response);
221+
});
209222
registerHandler(V3, [this](const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, HttpResponseComponents& response_components, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) -> Status {
210223
OVMS_PROFILE_FUNCTION();
211224
return processV3(uri, request_components, response, request_body, std::move(serverReaderWriter), std::move(multiPartParser));
@@ -559,6 +572,89 @@ static Status createV3HttpPayload(
559572
}
560573
#endif
561574

575+
void parseModel(rapidjson::Writer<rapidjson::StringBuffer>& writer, const std::string& name, const time_t timestamp) {
576+
writer.StartObject();
577+
writer.String("id");
578+
writer.String(name.c_str());
579+
writer.String("object");
580+
writer.String("model");
581+
writer.String("created");
582+
writer.Int64(timestamp);
583+
writer.String("owned_by");
584+
writer.String("OVMS");
585+
writer.EndObject();
586+
}
587+
588+
Status HttpRestApiHandler::processRetrieveModelRequest(const std::string& name, std::string& response) {
589+
const std::map<std::string, std::shared_ptr<Model>>& models = modelManager.getModels();
590+
bool exist = false;
591+
auto it = models.find(name);
592+
if (it != models.end())
593+
exist = true;
594+
const std::vector<std::string>& pipelinesNames = modelManager.getPipelineFactory().getPipelinesNames();
595+
if (std::find(pipelinesNames.begin(), pipelinesNames.end(), name) != pipelinesNames.end())
596+
exist = true;
597+
#if (MEDIAPIPE_DISABLE == 0)
598+
auto mediapipes = modelManager.getMediapipeFactory().getMediapipePipelinesNames();
599+
if (std::find(mediapipes.begin(), mediapipes.end(), name) != mediapipes.end())
600+
exist = true;
601+
#endif
602+
rapidjson::StringBuffer buffer;
603+
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
604+
if (!exist) {
605+
writer.StartObject();
606+
writer.String("error");
607+
writer.String("Model not found");
608+
writer.EndObject();
609+
response = buffer.GetString();
610+
return StatusCode::MODEL_NOT_LOADED;
611+
}
612+
time_t timestamp;
613+
time(&timestamp);
614+
writer.StartObject();
615+
writer.String("object");
616+
writer.String("list");
617+
writer.String("data");
618+
writer.StartArray();
619+
parseModel(writer, name, timestamp);
620+
writer.EndArray();
621+
writer.EndObject();
622+
response = buffer.GetString();
623+
return StatusCode::OK;
624+
}
625+
626+
Status HttpRestApiHandler::processListModelsRequest(std::string& response) {
627+
const std::map<std::string, std::shared_ptr<Model>>& models = modelManager.getModels();
628+
rapidjson::StringBuffer buffer;
629+
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
630+
time_t timestamp;
631+
time(&timestamp);
632+
writer.StartObject();
633+
writer.String("object");
634+
writer.String("list");
635+
writer.String("data");
636+
writer.StartArray();
637+
for (auto const& model : models) {
638+
parseModel(writer, model.first, timestamp);
639+
}
640+
const std::vector<std::string>& pipelinesNames = modelManager.getPipelineFactory().getPipelinesNames();
641+
for (auto const& pipelineName : pipelinesNames) {
642+
parseModel(writer, pipelineName, timestamp);
643+
}
644+
#if (MEDIAPIPE_DISABLE == 0)
645+
auto mediapipes = modelManager.getMediapipeFactory().getMediapipePipelinesNames();
646+
for (auto const& graphName : mediapipes) {
647+
parseModel(writer, graphName, timestamp);
648+
}
649+
#endif
650+
writer.EndArray();
651+
writer.String("object");
652+
writer.String("list");
653+
writer.EndObject();
654+
response = buffer.GetString();
655+
return StatusCode::OK;
656+
}
657+
562658
Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) {
563659
#if (MEDIAPIPE_DISABLE == 0)
564660
OVMS_PROFILE_FUNCTION();
@@ -817,6 +913,8 @@ Status HttpRestApiHandler::parseRequestComponents(HttpRequestComponents& request
817913
std::regex_match(request_path, sm, kfs_servermetadataRegex) ||
818914
std::regex_match(request_path, sm, kfs_modelmetadataRegex) ||
819915
std::regex_match(request_path, sm, kfs_modelreadyRegex) ||
916+
std::regex_match(request_path, sm, v3_ListModelsRegex) ||
917+
std::regex_match(request_path, sm, v3_RetrieveModelRegex) ||
820918
std::regex_match(request_path, sm, metricsRegex))
821919
? StatusCode::REST_UNSUPPORTED_METHOD
822920
: StatusCode::REST_INVALID_URL;
@@ -886,6 +984,15 @@ Status HttpRestApiHandler::parseRequestComponents(HttpRequestComponents& request
886984
requestComponents.type = Metrics;
887985
return StatusCode::OK;
888986
}
987+
if (std::regex_match(request_path, sm, v3_ListModelsRegex)) {
988+
requestComponents.type = V3_ListModels;
989+
return StatusCode::OK;
990+
}
991+
if (std::regex_match(request_path, sm, v3_RetrieveModelRegex)) {
992+
requestComponents.model_name = urlDecode(sm[2]);
993+
requestComponents.type = V3_RetrieveModel;
994+
return StatusCode::OK;
995+
}
889996
return (std::regex_match(request_path, sm, predictionRegex) ||
890997
std::regex_match(request_path, sm, kfs_inferRegex, std::regex_constants::match_any) ||
891998
std::regex_match(request_path, sm, configReloadRegex))
@@ -908,20 +1015,17 @@ Status HttpRestApiHandler::processRequest(
9081015
HttpResponseComponents& responseComponents,
9091016
std::shared_ptr<HttpAsyncWriter> serverReaderWriter,
9101017
std::shared_ptr<MultiPartParser> multiPartParser) {
911-
9121018
std::smatch sm;
9131019
std::string request_path_str(request_path);
9141020
if (FileSystem::isPathEscaped(request_path_str)) {
9151021
SPDLOG_DEBUG("Path {} escape with .. is forbidden.", request_path);
9161022
return StatusCode::PATH_INVALID;
9171023
}
918-
9191024
HttpRequestComponents requestComponents;
9201025
auto status = parseRequestComponents(requestComponents, http_method, request_path_str, *headers);
9211026

9221027
if (!status.ok())
9231028
return status;
924-
9251029
response->clear();
9261030
return dispatchToProcessor(request_path, request_body, response, requestComponents, responseComponents, std::move(serverReaderWriter), std::move(multiPartParser));
9271031
}

src/http_rest_api_handler.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ enum RequestType { Predict,
5656
KFS_GetServerReady,
5757
KFS_GetServerLive,
5858
KFS_GetServerMetadata,
59+
V3_ListModels,
60+
V3_RetrieveModel,
5961
V3,
6062
Metrics,
6163
Options };
@@ -104,7 +106,10 @@ class HttpRestApiHandler {
104106
static const std::string kfs_serverliveRegexExp;
105107
static const std::string kfs_servermetadataRegexExp;
106108

109+
static const std::string v3_ListModelsRegexExp;
110+
static const std::string v3_RetrieveModelRegexExp;
107111
static const std::string v3_RegexExp;
112+
108113
/**
109114
* @brief Construct a new HttpRest Api Handler
110115
*
@@ -234,6 +239,8 @@ class HttpRestApiHandler {
234239
Status processServerMetadataKFSRequest(const HttpRequestComponents& request_components, std::string& response, const std::string& request_body);
235240

236241
Status processV3(const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser);
242+
Status processListModelsRequest(std::string& response);
243+
Status processRetrieveModelRequest(const std::string& name, std::string& response);
237244

238245
private:
239246
const std::regex predictionRegex;
@@ -249,6 +256,8 @@ class HttpRestApiHandler {
249256
const std::regex kfs_serverliveRegex;
250257
const std::regex kfs_servermetadataRegex;
251258

259+
const std::regex v3_ListModelsRegex;
260+
const std::regex v3_RetrieveModelRegex;
252261
const std::regex v3_Regex;
253262

254263
const std::regex metricsRegex;

0 commit comments

Comments
 (0)