Skip to content

Commit 0de8878

Browse files
ngxson and ggerganov authored
server: split HTTP into its own interface (#17216)
* server: split HTTP into its own interface * move server-http and httplib to its own file * add the remaining endpoints * fix exception/error handling * renaming * missing header * fix missing windows header * fix error responses from http layer * fix slot save/restore handler * fix case where only one stream chunk is returned * add NOMINMAX * do not call sink.write on empty data * use safe_json_to_str for SSE * clean up * add some comments * improve usage of next() * bring back the "server is listening on" message * more generic handler * add req.headers * move the chat template print to init() * add req.path * cont : minor --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 38e2c1b commit 0de8878

File tree

5 files changed

+1155
-840
lines changed

5 files changed

+1155
-840
lines changed

tools/server/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ endif()
1414
set(TARGET_SRCS
1515
server.cpp
1616
utils.hpp
17+
server-http.cpp
18+
server-http.h
1719
)
1820
set(PUBLIC_ASSETS
1921
index.html.gz

tools/server/server-http.cpp

Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
#include "utils.hpp"
2+
#include "common.h"
3+
#include "server-http.h"
4+
5+
#include <cpp-httplib/httplib.h>
6+
7+
#include <functional>
8+
#include <string>
9+
#include <thread>
10+
11+
// auto generated files (see README.md for details)
12+
#include "index.html.gz.hpp"
13+
#include "loading.html.hpp"
14+
15+
//
16+
// HTTP implementation using cpp-httplib
17+
//
18+
19+
// Private implementation of server_http_context. Keeping the cpp-httplib
// types in this translation unit means they are only needed here.
class server_http_context::Impl {
public:
    // the cpp-httplib server instance; null until one is assigned
    std::unique_ptr<httplib::Server> srv = nullptr;
};
23+
24+
// Allocates the (initially empty) private implementation object.
server_http_context::server_http_context() : pimpl(std::make_unique<Impl>()) {
}
27+
28+
// Defaulted here, after the definition of Impl above, so that the destructor
// is generated in a place where Impl is a complete type.
server_http_context::~server_http_context() = default;
29+
30+
// Logger callback for the HTTP server: prints method, path, remote address and
// status at info level, and the request/response bodies at debug level.
//
// reminder: this function is not covered by httplib's exception handler; if
// someone does more complicated stuff, think about wrapping it in try-catch
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
    // requests to /v1/health are not logged (GH copilot sends them when the default port is used)
    const bool is_health_probe = (req.path == "/v1/health");

    if (!is_health_probe) {
        SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

        SRV_DBG("request: %s\n", req.body.c_str());
        SRV_DBG("response: %s\n", res.body.c_str());
    }
}
43+
44+
bool server_http_context::init(const common_params & params) {
45+
path_prefix = params.api_prefix;
46+
port = params.port;
47+
hostname = params.hostname;
48+
49+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
50+
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
51+
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
52+
svr.reset(
53+
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
54+
);
55+
} else {
56+
LOG_INF("Running without SSL\n");
57+
svr.reset(new httplib::Server());
58+
}
59+
#else
60+
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
61+
LOG_ERR("Server is built without SSL support\n");
62+
return false;
63+
}
64+
pimpl->srv.reset(new httplib::Server());
65+
#endif
66+
67+
auto & srv = pimpl->srv;
68+
srv->set_default_headers({{"Server", "llama.cpp"}});
69+
srv->set_logger(log_server_request);
70+
srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
71+
// this is fail-safe; exceptions should already handled by `ex_wrapper`
72+
73+
std::string message;
74+
try {
75+
std::rethrow_exception(ep);
76+
} catch (const std::exception & e) {
77+
message = e.what();
78+
} catch (...) {
79+
message = "Unknown Exception";
80+
}
81+
82+
res.status = 500;
83+
res.set_content(message, "text/plain");
84+
LOG_ERR("got exception: %s\n", message.c_str());
85+
});
86+
87+
srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
88+
if (res.status == 404) {
89+
res.set_content(
90+
safe_json_to_str(json {
91+
{"error", {
92+
{"message", "File Not Found"},
93+
{"type", "not_found_error"},
94+
{"code", 404}
95+
}}
96+
}),
97+
"application/json; charset=utf-8"
98+
);
99+
}
100+
// for other error codes, we skip processing here because it's already done by res->error()
101+
});
102+
103+
// set timeouts and change hostname and port
104+
srv->set_read_timeout (params.timeout_read);
105+
srv->set_write_timeout(params.timeout_write);
106+
107+
if (params.api_keys.size() == 1) {
108+
auto key = params.api_keys[0];
109+
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
110+
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
111+
} else if (params.api_keys.size() > 1) {
112+
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
113+
}
114+
115+
//
116+
// Middlewares
117+
//
118+
119+
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
120+
static const std::unordered_set<std::string> public_endpoints = {
121+
"/health",
122+
"/v1/health",
123+
"/models",
124+
"/v1/models",
125+
"/api/tags"
126+
};
127+
128+
// If API key is not set, skip validation
129+
if (api_keys.empty()) {
130+
return true;
131+
}
132+
133+
// If path is public or is static file, skip validation
134+
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
135+
return true;
136+
}
137+
138+
// Check for API key in the header
139+
auto auth_header = req.get_header_value("Authorization");
140+
141+
std::string prefix = "Bearer ";
142+
if (auth_header.substr(0, prefix.size()) == prefix) {
143+
std::string received_api_key = auth_header.substr(prefix.size());
144+
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
145+
return true; // API key is valid
146+
}
147+
}
148+
149+
// API key is invalid or not provided
150+
res.status = 401;
151+
res.set_content(
152+
safe_json_to_str(json {
153+
{"error", {
154+
{"message", "Invalid API Key"},
155+
{"type", "authentication_error"},
156+
{"code", 401}
157+
}}
158+
}),
159+
"application/json; charset=utf-8"
160+
);
161+
162+
LOG_WRN("Unauthorized: Invalid API Key\n");
163+
164+
return false;
165+
};
166+
167+
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
168+
bool ready = is_ready.load();
169+
if (!ready) {
170+
auto tmp = string_split<std::string>(req.path, '.');
171+
if (req.path == "/" || tmp.back() == "html") {
172+
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
173+
res.status = 503;
174+
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
175+
// allow the models endpoint to be accessed during loading
176+
return true;
177+
} else {
178+
res.status = 503;
179+
res.set_content(
180+
safe_json_to_str(json {
181+
{"error", {
182+
{"message", "Loading model"},
183+
{"type", "unavailable_error"},
184+
{"code", 503}
185+
}}
186+
}),
187+
"application/json; charset=utf-8"
188+
);
189+
}
190+
return false;
191+
}
192+
return true;
193+
};
194+
195+
// register server middlewares
196+
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
197+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
198+
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
199+
if (req.method == "OPTIONS") {
200+
res.set_header("Access-Control-Allow-Credentials", "true");
201+
res.set_header("Access-Control-Allow-Methods", "GET, POST");
202+
res.set_header("Access-Control-Allow-Headers", "*");
203+
res.set_content("", "text/html"); // blank response, no data
204+
return httplib::Server::HandlerResponse::Handled; // skip further processing
205+
}
206+
if (!middleware_server_state(req, res)) {
207+
return httplib::Server::HandlerResponse::Handled;
208+
}
209+
if (!middleware_validate_api_key(req, res)) {
210+
return httplib::Server::HandlerResponse::Handled;
211+
}
212+
return httplib::Server::HandlerResponse::Unhandled;
213+
});
214+
215+
int n_threads_http = params.n_threads_http;
216+
if (n_threads_http < 1) {
217+
// +2 threads for monitoring endpoints
218+
n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
219+
}
220+
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
221+
srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
222+
223+
//
224+
// Web UI setup
225+
//
226+
227+
if (!params.webui) {
228+
LOG_INF("Web UI is disabled\n");
229+
} else {
230+
// register static assets routes
231+
if (!params.public_path.empty()) {
232+
// Set the base directory for serving static files
233+
bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
234+
if (!is_found) {
235+
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
236+
return 1;
237+
}
238+
} else {
239+
// using embedded static index.html
240+
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
241+
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
242+
res.set_content("Error: gzip is not supported by this browser", "text/plain");
243+
} else {
244+
res.set_header("Content-Encoding", "gzip");
245+
// COEP and COOP headers, required by pyodide (python interpreter)
246+
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
247+
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
248+
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
249+
}
250+
return false;
251+
});
252+
}
253+
}
254+
return true;
255+
}
256+
257+
bool server_http_context::start() {
258+
// Bind and listen
259+
260+
auto & srv = pimpl->srv;
261+
bool was_bound = false;
262+
bool is_sock = false;
263+
if (string_ends_with(std::string(hostname), ".sock")) {
264+
is_sock = true;
265+
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
266+
srv->set_address_family(AF_UNIX);
267+
// bind_to_port requires a second arg, any value other than 0 should
268+
// simply get ignored
269+
was_bound = srv->bind_to_port(hostname, 8080);
270+
} else {
271+
LOG_INF("%s: binding port with default address family\n", __func__);
272+
// bind HTTP listen port
273+
if (port == 0) {
274+
int bound_port = srv->bind_to_any_port(hostname);
275+
was_bound = (bound_port >= 0);
276+
if (was_bound) {
277+
port = bound_port;
278+
}
279+
} else {
280+
was_bound = srv->bind_to_port(hostname, port);
281+
}
282+
}
283+
284+
if (!was_bound) {
285+
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
286+
return false;
287+
}
288+
289+
// run the HTTP server in a thread
290+
thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
291+
srv->wait_until_ready();
292+
293+
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
294+
: string_format("http://%s:%d", hostname.c_str(), port);
295+
return true;
296+
}
297+
298+
void server_http_context::stop() const {
299+
if (pimpl->srv) {
300+
pimpl->srv->stop();
301+
}
302+
}
303+
304+
// Copy every (name, value) pair from `headers` onto the httplib response.
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
    for (const auto & header : headers) {
        res.set_header(header.first, header.second);
    }
}
309+
310+
// Flatten the request's `params` and `path_params` into a single map.
// Entries are written in that order, so for a repeated key the last value
// written wins (path_params override params).
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
    std::map<std::string, std::string> merged;
    for (const auto & query_entry : req.params) {
        merged.insert_or_assign(query_entry.first, query_entry.second);
    }
    for (const auto & path_entry : req.path_params) {
        merged.insert_or_assign(path_entry.first, path_entry.second);
    }
    return merged;
}
320+
321+
// Copy the request headers into a plain std::map. When a header name occurs
// more than once, the value visited last during iteration is the one kept.
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
    std::map<std::string, std::string> result;
    for (const auto & header : req.headers) {
        result.insert_or_assign(header.first, header.second);
    }
    return result;
}
328+
329+
// Translate a handler result into the httplib response.
//
// Status and headers are always copied. A non-stream result has its `data`
// sent as the body. A stream result is moved out of `response` and served
// through a chunked content provider that pulls chunks via next() until it
// returns false.
static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
    const bool streaming = response->is_stream();

    res.status = response->status;
    set_headers(res, response->headers);

    if (!streaming) {
        res.set_content(response->data, response->content_type);
        return;
    }

    std::string content_type = response->content_type;

    // convert to shared_ptr as both the content provider and the completion callback need to use it
    std::shared_ptr<server_http_res> shared_res = std::move(response);

    const auto provide_chunk = [shared_res](size_t, httplib::DataSink & sink) -> bool {
        std::string chunk;
        const bool more = shared_res->next(chunk);
        if (!chunk.empty()) {
            // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
            sink.write(chunk.data(), chunk.size());
            SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
        }
        if (!more) {
            sink.done();
            SRV_DBG("%s", "http: stream ended\n");
        }
        return more;
    };

    const auto release_response = [shared_res](bool) mutable {
        shared_res.reset(); // trigger the destruction of the response object
    };

    res.set_chunked_content_provider(content_type, provide_chunk, release_response);
}
360+
361+
// Register `handler` for GET requests on `path_prefix + path`. Each incoming
// request is converted to a server_http_req, passed to the handler, and the
// handler's result is written to the httplib response.
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    auto on_request = [handler](const httplib::Request & http_req, httplib::Response & http_res) {
        server_http_res_ptr result = handler(server_http_req{
            get_params(http_req),
            get_headers(http_req),
            http_req.path,
            http_req.body,
            http_req.is_connection_closed
        });
        process_handler_response(result, http_res);
    };

    pimpl->srv->Get(route, std::move(on_request));
}
373+
374+
// Register `handler` for POST requests on `path_prefix + path`. Each incoming
// request is converted to a server_http_req, passed to the handler, and the
// handler's result is written to the httplib response.
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
    const std::string route = path_prefix + path;

    auto on_request = [handler](const httplib::Request & http_req, httplib::Response & http_res) {
        server_http_res_ptr result = handler(server_http_req{
            get_params(http_req),
            get_headers(http_req),
            http_req.path,
            http_req.body,
            http_req.is_connection_closed
        });
        process_handler_response(result, http_res);
    };

    pimpl->srv->Post(route, std::move(on_request));
}
386+

0 commit comments

Comments
 (0)