Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 92 additions & 21 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,23 @@
#include <string>
#include <thread>
#include <vector>
#include <memory>
#include <csignal>
#include <atomic>
#include <functional>
#include <cstdlib>
#if defined (_WIN32)
#include <windows.h>
#endif

using namespace httplib;
using json = nlohmann::ordered_json;

enum server_state {
    // Model load is still in progress; /health reports "loading model" in this state
    SERVER_STATE_LOADING_MODEL = 0,
    // Model load has finished; /health reports "ok" in this state
    SERVER_STATE_READY = 1,
};

namespace {

// output formats
Expand All @@ -26,6 +39,20 @@ const std::string srt_format = "srt";
const std::string vjson_format = "verbose_json";
const std::string vtt_format = "vtt";

// Callback run on the first termination request. It stays empty until the
// program assigns it, so signal_handler() must tolerate the empty case.
std::function<void(int)> shutdown_handler;
// Set by the first termination request; a second request sees it already set.
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

// Handles an interrupt/termination request.
//
// The first request forwards `signal` to `shutdown_handler` (if one has been
// installed). Any later request terminates the process immediately with exit
// status 1.
//
// NOTE(review): fprintf() and exit() are not on the POSIX async-signal-safe
// list; they are kept as-is here for the developer convenience described below.
inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
        // this is for better developer experience, we can remove when the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    // Calling an empty std::function throws std::bad_function_call, which must
    // never escape a signal handler; skip the call when nothing is installed.
    if (shutdown_handler) {
        shutdown_handler(signal);
    }
}

struct server_params
{
std::string hostname = "127.0.0.1";
Expand Down Expand Up @@ -593,6 +620,9 @@ int main(int argc, char ** argv) {
}
}

std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

if (ctx == nullptr) {
Expand All @@ -602,9 +632,10 @@ int main(int argc, char ** argv) {

// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
state.store(SERVER_STATE_READY);


Server svr;
svr.set_default_headers({{"Server", "whisper.cpp"},
svr->set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type, authorization"}});

Expand Down Expand Up @@ -683,15 +714,15 @@ int main(int argc, char ** argv) {
whisper_params default_params = params;

// this is only called if no index.html is found in the public --path
svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
svr->Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html");
return false;
});

svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
});

svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
// acquire whisper model mutex lock
std::lock_guard<std::mutex> lock(whisper_mutex);

Expand Down Expand Up @@ -997,8 +1028,9 @@ int main(int argc, char ** argv) {
// reset params to their defaults
params = default_params;
});
svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
std::lock_guard<std::mutex> lock(whisper_mutex);
state.store(SERVER_STATE_LOADING_MODEL);
if (!req.has_file("model"))
{
fprintf(stderr, "error: no 'model' field in the request\n");
Expand Down Expand Up @@ -1030,18 +1062,25 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);

state.store(SERVER_STATE_READY);
const std::string success = "Load was successful!";
res.set_content(success, "application/text");

// check if the model is in the file system
});

svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){
server_state current_state = state.load();
if (current_state == SERVER_STATE_READY) {
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
} else {
res.set_content("{\"status\":\"loading model\"}", "application/json");
res.status = 503;
}
});

svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
Expand All @@ -1055,7 +1094,7 @@ int main(int argc, char ** argv) {
res.status = 500;
});

svr.set_error_handler([](const Request &req, Response &res) {
svr->set_error_handler([](const Request &req, Response &res) {
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else if (res.status != 500) {
Expand All @@ -1065,29 +1104,61 @@ int main(int argc, char ** argv) {
});

// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
svr->set_read_timeout(sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(sparams.hostname, sparams.port))
if (!svr->bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port);
return 1;
}

// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
svr->set_base_dir(sparams.public_path);

// to make it ctrl+clickable:
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);

if (!svr.listen_after_bind())
{
return 1;
}
shutdown_handler = [&](int signal) {
printf("\nCaught signal %d, shutting down gracefully...\n", signal);
if (svr) {
svr->stop();
}
};

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

// clean up function, to be called before exit
auto clean_up = [&ctx]() {
whisper_print_timings(ctx);
whisper_free(ctx);
};

std::thread t([&] {
if (!svr->listen_after_bind()) {
fprintf(stderr, "error: server listen failed\n");
}
});

svr->wait_until_ready();

t.join();


whisper_print_timings(ctx);
whisper_free(ctx);
clean_up();

return 0;
}
Loading