|
2 | 2 |
|
3 | 3 | #include "CatBoostLibraryHandlerFactory.h" |
4 | 4 | #include "Common/ProfileEvents.h" |
5 | | -#include "ExternalDictionaryLibraryHandler.h" |
6 | 5 | #include "ExternalDictionaryLibraryHandlerFactory.h" |
7 | 6 |
|
8 | 7 | #include <Formats/FormatFactory.h> |
9 | 8 | #include <IO/ReadBufferFromString.h> |
10 | 9 | #include <IO/ReadHelpers.h> |
11 | 10 | #include <Common/BridgeProtocolVersion.h> |
| 11 | +#include <Common/filesystemHelpers.h> |
12 | 12 | #include <IO/WriteHelpers.h> |
13 | 13 | #include <Poco/Net/HTTPServerRequest.h> |
14 | 14 | #include <Poco/Net/HTTPServerResponse.h> |
|
24 | 24 | #include <Formats/NativeReader.h> |
25 | 25 | #include <Formats/NativeWriter.h> |
26 | 26 | #include <DataTypes/DataTypesNumber.h> |
27 | | -#include <DataTypes/DataTypeString.h> |
| 27 | +#include <filesystem> |
| 28 | +#include <boost/algorithm/string/join.hpp> |
28 | 29 |
|
29 | 30 |
|
30 | 31 | namespace DB |
@@ -86,14 +87,26 @@ static void writeData(Block data, OutputFormatPtr format) |
86 | 87 | } |
87 | 88 |
|
88 | 89 |
|
89 | | -ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_) |
| 90 | +ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_, std::vector<std::string> libraries_paths_) |
90 | 91 | : WithContext(context_) |
91 | 92 | , keep_alive_timeout(keep_alive_timeout_) |
92 | 93 | , log(getLogger("ExternalDictionaryLibraryBridgeRequestHandler")) |
| 94 | + , libraries_paths(std::move(libraries_paths_)) |
93 | 95 | { |
94 | 96 | } |
95 | 97 |
|
96 | 98 |
|
| 99 | +static bool checkLibraryPath(const std::string & path, const std::vector<std::string> & allowed_prefixes, HTTPServerResponse & response) |
| 100 | +{ |
| 101 | + for (const auto & prefix : allowed_prefixes) |
| 102 | + if (fileOrSymlinkPathStartsWith(path, prefix)) |
| 103 | + return true; |
| 104 | + processError(response, fmt::format("The provided library path {} is not inside any of the allowed prefixes {} ('libraries-path') from the configuration.", |
| 105 | + path, boost::join(allowed_prefixes, ", "))); |
| 106 | + return false; |
| 107 | +} |
| 108 | + |
| 109 | + |
97 | 110 | void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & /*write_event*/) |
98 | 111 | { |
99 | 112 | LOG_TRACE(log, "Request URI: {}", request.getURI()); |
@@ -172,12 +185,15 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ |
172 | 185 |
|
173 | 186 | if (!params.has("library_path")) |
174 | 187 | { |
175 | | - processError(response, "No 'library_path' in request URL"); |
| 188 | + processError(response, "No 'library_path' in the request URL"); |
176 | 189 | return; |
177 | 190 | } |
178 | 191 |
|
179 | 192 | const String & library_path = params.get("library_path"); |
180 | 193 |
|
| 194 | + if (!checkLibraryPath(library_path, libraries_paths, response)) |
| 195 | + return; |
| 196 | + |
181 | 197 | if (!params.has("library_settings")) |
182 | 198 | { |
183 | 199 | processError(response, "No 'library_settings' in request URL"); |
@@ -413,10 +429,11 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque |
413 | 429 |
|
414 | 430 |
|
415 | 431 | CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler( |
416 | | - size_t keep_alive_timeout_, ContextPtr context_) |
| 432 | + size_t keep_alive_timeout_, ContextPtr context_, std::vector<std::string> libraries_paths_) |
417 | 433 | : WithContext(context_) |
418 | 434 | , keep_alive_timeout(keep_alive_timeout_) |
419 | 435 | , log(getLogger("CatBoostLibraryBridgeRequestHandler")) |
| 436 | + , libraries_paths(std::move(libraries_paths_)) |
420 | 437 | { |
421 | 438 | } |
422 | 439 |
|
@@ -522,6 +539,9 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ |
522 | 539 |
|
523 | 540 | const String & library_path = params.get("library_path"); |
524 | 541 |
|
| 542 | + if (!checkLibraryPath(library_path, libraries_paths, response)) |
| 543 | + return; |
| 544 | + |
525 | 545 | if (!params.has("model_path")) |
526 | 546 | { |
527 | 547 | processError(response, "No 'model_path' in request URL"); |
|
0 commit comments