diff --git a/examples/server/bench.js b/examples/server/bench.js index 127544357eb..baf8c975f55 100644 --- a/examples/server/bench.js +++ b/examples/server/bench.js @@ -2,8 +2,13 @@ import http from 'k6/http' import { check } from 'k6' export let options = { - vus: parseInt(__ENV.CONCURRENCY) || 4, - iterations: parseInt(__ENV.CONCURRENCY) || 4, + scenarios: { + load_test: { + executor: 'constant-vus', + vus: parseInt(__ENV.CONCURRENCY) || 8, + duration: '1m', + }, + }, } const filePath = __ENV.FILE_PATH diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8b6c5a96720..f9eb7161194 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #if defined (_WIN32) #include #endif @@ -41,10 +43,10 @@ const std::string vjson_format = "verbose_json"; const std::string vtt_format = "vtt"; std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; +std::atomic is_terminating{false}; inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { + if (is_terminating.exchange(true)) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice // this is for better developer experience, we can remove when the server is stable enough fprintf(stderr, "Received second interrupt, terminating immediately.\n"); @@ -64,6 +66,7 @@ struct server_params int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; + int32_t workers = 1; bool ffmpeg_converter = false; }; @@ -169,6 +172,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); + fprintf(stderr, " --workers N, [%-7d] Number of worker 
threads for the server\n", sparams.workers); fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); @@ -241,6 +245,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } + else if ( arg == "--workers") { sparams.workers = std::stoi(argv[++i]); } else if ( arg == "--public") { sparams.public_path = argv[++i]; } else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } @@ -603,14 +608,451 @@ void get_req_parameters(const Request & req, whisper_params & params) } // namespace +enum class AbortReason { + NotAborted, + ClientDisconnect, + ServerShutdown +}; + +struct WhisperTaskResult { + bool success = false; + std::string error_msg; + std::string output; + std::string content_type = "application/json"; + int status_code = 200; +}; + +struct WhisperTask { + std::vector pcmf32; + std::vector> pcmf32s; + whisper_params params; + std::string filename; + const httplib::Request* request_ptr; // For abort callback + std::atomic abort_reason{AbortReason::NotAborted}; + std::atomic* stop_flag_ptr{nullptr}; + + // Synchronization for result + std::mutex result_mutex; + std::condition_variable result_cv; + WhisperTaskResult result; + std::atomic completed{false}; + + // Delete copy constructor and copy assignment + WhisperTask(const WhisperTask&) = delete; + WhisperTask& operator=(const WhisperTask&) = delete; + + // Default constructor + WhisperTask() = default; + + // Move constructor + WhisperTask(WhisperTask&& other) noexcept + : 
pcmf32(std::move(other.pcmf32)), + pcmf32s(std::move(other.pcmf32s)), + params(std::move(other.params)), + filename(std::move(other.filename)), + request_ptr(other.request_ptr), + abort_reason(other.abort_reason.load()), + stop_flag_ptr(other.stop_flag_ptr), + result_mutex(), + result_cv(), + result(std::move(other.result)), + completed(other.completed.load()) {} +}; + +class WhisperTaskQueue { +public: + WhisperTaskQueue(struct whisper_context* ctx, size_t n_workers = 2) + : ctx_(ctx), stop_flag_(false) { + printf("Creating task queue with %zu workers\n", n_workers); + for (size_t i = 0; i < n_workers; ++i) { + workers_.emplace_back(&WhisperTaskQueue::workerLoop, this, i); + } + } + + ~WhisperTaskQueue() { + shutdown(); + } + + WhisperTaskResult executeTask(WhisperTask&& task) { + printf("Queueing task: %s\n", task.filename.c_str()); + + // Create task on heap and move data into it + auto task_ptr = std::make_shared(); + task_ptr->pcmf32 = std::move(task.pcmf32); + task_ptr->pcmf32s = std::move(task.pcmf32s); + task_ptr->params = std::move(task.params); + task_ptr->filename = std::move(task.filename); + task_ptr->request_ptr = task.request_ptr; + task_ptr->stop_flag_ptr = &stop_flag_; + + { + std::lock_guard lock(queue_mutex_); + if (stop_flag_.load()) { + printf("Task queue is shutting down, rejecting task: %s\n", task_ptr->filename.c_str()); + WhisperTaskResult result; + result.success = false; + result.error_msg = "{\"error\":\"Task queue is shutting down\"}"; + result.status_code = 503; + return result; + } + tasks_.push(task_ptr); + printf("Task queued, queue size: %zu\n", tasks_.size()); + } + queue_cv_.notify_one(); + + printf("Waiting for task completion: %s\n", task_ptr->filename.c_str()); + // Wait for completion + std::unique_lock lock(task_ptr->result_mutex); + task_ptr->result_cv.wait(lock, [&]{ return task_ptr->completed.load(); }); + + printf("Task completed: %s\n", task_ptr->filename.c_str()); + return task_ptr->result; + } + + void shutdown() { + 
{ + std::lock_guard lock(queue_mutex_); + if (stop_flag_.load()) { + return; + } + printf("[shutdown] Initiating task queue shutdown.\n"); + stop_flag_ = true; + + // Drain the queue and notify waiting clients + printf("[shutdown] Draining %zu pending tasks...\n", tasks_.size()); + while (!tasks_.empty()) { + auto task_ptr = tasks_.front(); + tasks_.pop(); + printf("[shutdown] Rejecting pending task for file: %s\n", task_ptr->filename.c_str()); + + WhisperTaskResult result; + result.success = false; + result.error_msg = "{\"error\":\"Server is shutting down\"}"; + result.status_code = 503; // Service Unavailable + + // Notify completion + { + std::lock_guard res_lock(task_ptr->result_mutex); + task_ptr->result = std::move(result); + task_ptr->completed.store(true); + } + task_ptr->result_cv.notify_one(); + } + } + queue_cv_.notify_all(); + + printf("[shutdown] Waiting for worker threads to join.\n"); + for (auto& worker : workers_) { + if (worker.joinable()) { + worker.join(); + } + } + printf("[shutdown] All workers joined.\n"); + } + +private: + void workerLoop(size_t worker_id) { + printf("Worker %zu started\n", worker_id); + while (true) { + std::shared_ptr task; + + { + std::unique_lock lock(queue_mutex_); + queue_cv_.wait(lock, [&]{ return stop_flag_.load() || !tasks_.empty(); }); + + if (stop_flag_.load()) { + printf("Worker %zu received shutdown signal and is terminating.\n", worker_id); + return; + } + + task = tasks_.front(); + tasks_.pop(); + printf("Worker %zu picked up task: %s\n", worker_id, task->filename.c_str()); + } + + // Process the task with mutex protection for whisper context + WhisperTaskResult result; + { + std::lock_guard ctx_lock(whisper_mutex_); + printf("Worker %zu processing task: %s\n", worker_id, task->filename.c_str()); + result = processWhisperTask(*task); + printf("Worker %zu completed task: %s\n", worker_id, task->filename.c_str()); + } + + // Notify completion + { + std::lock_guard lock(task->result_mutex); + task->result = 
std::move(result); + task->completed.store(true); + } + task->result_cv.notify_one(); + } + } + + WhisperTaskResult processWhisperTask(WhisperTask& task) { + WhisperTaskResult result; + + try { + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + task.params.n_threads*task.params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); + } + + // print some info about the processing + { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", + __func__, task.filename.c_str(), int(task.pcmf32.size()), float(task.pcmf32.size())/WHISPER_SAMPLE_RATE, + task.params.n_threads, task.params.n_processors, + task.params.language.c_str(), + task.params.translate ? "translate" : "transcribe", + task.params.tinydiarize ? "tdrz = 1, " : "", + task.params.no_timestamps ? 0 : 1); + fprintf(stderr, "\n"); + } + + // run the inference + { + printf("Running whisper.cpp inference on %s\n", task.filename.c_str()); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + wparams.strategy = task.params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = task.params.print_progress; + wparams.print_timestamps = !task.params.no_timestamps; + wparams.print_special = task.params.print_special; + wparams.translate = task.params.translate; + wparams.language = task.params.language.c_str(); + wparams.detect_language = task.params.detect_language; + wparams.n_threads = task.params.n_threads; + wparams.n_max_text_ctx = task.params.max_context >= 0 ? task.params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = task.params.offset_t_ms; + wparams.duration_ms = task.params.duration_ms; + + wparams.thold_pt = task.params.word_thold; + wparams.max_len = task.params.max_len == 0 ? 
60 : task.params.max_len; + wparams.split_on_word = task.params.split_on_word; + wparams.audio_ctx = task.params.audio_ctx; + + wparams.debug_mode = task.params.debug_mode; + wparams.tdrz_enable = task.params.tinydiarize; + wparams.initial_prompt = task.params.prompt.c_str(); + + wparams.greedy.best_of = task.params.best_of; + wparams.beam_search.beam_size = task.params.beam_size; + + wparams.temperature = task.params.temperature; + wparams.no_speech_thold = task.params.no_speech_thold; + wparams.temperature_inc = task.params.temperature_inc; + wparams.entropy_thold = task.params.entropy_thold; + wparams.logprob_thold = task.params.logprob_thold; + + wparams.no_timestamps = task.params.no_timestamps; + wparams.token_timestamps = !task.params.no_timestamps && task.params.response_format == vjson_format; + wparams.no_context = task.params.no_context; + wparams.suppress_nst = task.params.suppress_nst; + + // NOTE(review): forward VAD settings - the pre-refactor inline handler set these and they were dropped in the task-queue refactor + wparams.vad = task.params.vad; + wparams.vad_model_path = task.params.vad_model.c_str(); + + wparams.vad_params.threshold = task.params.vad_threshold; + wparams.vad_params.min_speech_duration_ms = task.params.vad_min_speech_duration_ms; + wparams.vad_params.min_silence_duration_ms = task.params.vad_min_silence_duration_ms; + wparams.vad_params.max_speech_duration_s = task.params.vad_max_speech_duration_s; + wparams.vad_params.speech_pad_ms = task.params.vad_speech_pad_ms; + wparams.vad_params.samples_overlap = task.params.vad_samples_overlap; + + whisper_print_user_data user_data = { &task.params, &task.pcmf32s, 0 }; + + // this callback is called on each new segment + if (task.params.print_realtime) { + wparams.new_segment_callback = whisper_print_segment_callback; + wparams.new_segment_callback_user_data = &user_data; + } + + if (wparams.print_progress) { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; + } + + // tell whisper to abort if the HTTP connection closed + if (task.request_ptr) { + wparams.abort_callback = [](void *user_data) -> bool { + auto& task = *static_cast<WhisperTask*>(user_data); + if (task.stop_flag_ptr != nullptr && task.stop_flag_ptr->load()) { + task.abort_reason.store(AbortReason::ServerShutdown); + return true; + } + if (task.request_ptr->is_connection_closed()) { + task.abort_reason.store(AbortReason::ClientDisconnect); + return true; + } + return false; + }; + wparams.abort_callback_user_data = &task; + } + + if (whisper_full_parallel(ctx_, wparams, task.pcmf32.data(), task.pcmf32.size(), task.params.n_processors) != 0) { + //
handle failure or early abort + AbortReason reason = task.abort_reason.load(); + if (reason == AbortReason::ClientDisconnect) { + fprintf(stderr, "client disconnected, aborted processing\n"); + result.success = false; + result.error_msg = "{\"error\":\"client disconnected\"}"; + result.status_code = 499; // Client Closed Request (nginx convention) + return result; + } else if (reason == AbortReason::ServerShutdown) { + fprintf(stderr, "server shutting down, aborted processing\n"); + result.success = false; + result.error_msg = "{\"error\":\"server is shutting down\"}"; + result.status_code = 503; // Service Unavailable + return result; + } + + result.success = false; + result.error_msg = "{\"error\":\"failed to process audio\"}"; + result.status_code = 500; + return result; + } + } + + // format output based on response format + if (task.params.response_format == text_format) { + result.output = output_str(ctx_, task.params, task.pcmf32s); + result.content_type = "text/html; charset=utf-8"; + } + else if (task.params.response_format == srt_format) { + std::stringstream ss; + const int n_segments = whisper_full_n_segments(ctx_); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx_, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx_, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx_, i); + std::string speaker = ""; + + if (task.params.diarize && task.pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(task.pcmf32s, t0, t1); + } + + ss << i + 1 + task.params.offset_n << "\n"; + ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + ss << speaker << text << "\n\n"; + } + result.output = ss.str(); + result.content_type = "application/x-subrip"; + } + else if (task.params.response_format == vtt_format) { + std::stringstream ss; + ss << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx_); + for (int i = 0; i < n_segments; ++i) { + const char * text = 
whisper_full_get_segment_text(ctx_, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx_, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx_, i); + std::string speaker = ""; + + if (task.params.diarize && task.pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(task.pcmf32s, t0, t1, true); + speaker.insert(0, ""); + } + + ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + ss << speaker << text << "\n\n"; + } + result.output = ss.str(); + result.content_type = "text/vtt"; + } + else if (task.params.response_format == vjson_format) { + std::string results = output_str(ctx_, task.params, task.pcmf32s); + std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); + const auto detected_lang_id = whisper_lang_auto_detect(ctx_, 0, task.params.n_threads, lang_probs.data()); + json jres = json{ + {"task", task.params.translate ? "translate" : "transcribe"}, + {"language", whisper_lang_str_full(whisper_full_lang_id(ctx_))}, + {"duration", float(task.pcmf32.size())/WHISPER_SAMPLE_RATE}, + {"text", results}, + {"segments", json::array()}, + {"detected_language", whisper_lang_str_full(detected_lang_id)}, + {"detected_language_probability", lang_probs[detected_lang_id]}, + {"language_probabilities", json::object()} + }; + + for (int i = 0; i <= whisper_lang_max_id(); ++i) { + if (lang_probs[i] > 0.001f) { + jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i]; + } + } + + const int n_segments = whisper_full_n_segments(ctx_); + for (int i = 0; i < n_segments; ++i) { + json segment = json{ + {"id", i}, + {"text", whisper_full_get_segment_text(ctx_, i)}, + }; + + if (!task.params.no_timestamps) { + segment["start"] = whisper_full_get_segment_t0(ctx_, i) * 0.01; + segment["end"] = whisper_full_get_segment_t1(ctx_, i) * 0.01; + } + + float total_logprob = 0; + const int n_tokens = whisper_full_n_tokens(ctx_, i); + for (int j = 0; j < n_tokens; ++j) { + whisper_token_data token = whisper_full_get_token_data(ctx_, i, j); + if (token.id 
>= whisper_token_eot(ctx_)) { + continue; + } + + segment["tokens"].push_back(token.id); + json word = json{{"word", whisper_full_get_token_text(ctx_, i, j)}}; + if (!task.params.no_timestamps) { + word["start"] = token.t0 * 0.01; + word["end"] = token.t1 * 0.01; + word["t_dtw"] = token.t_dtw; + } + word["probability"] = token.p; + total_logprob += token.plog; + segment["words"].push_back(word); + } + + segment["temperature"] = task.params.temperature; + segment["avg_logprob"] = total_logprob / n_tokens; + segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx_, i); + + jres["segments"].push_back(segment); + } + result.output = jres.dump(-1, ' ', false, json::error_handler_t::replace); + result.content_type = "application/json"; + } + else { + std::string results = output_str(ctx_, task.params, task.pcmf32s); + json jres = json{{"text", results}}; + result.output = jres.dump(-1, ' ', false, json::error_handler_t::replace); + result.content_type = "application/json"; + } + + result.success = true; + result.status_code = 200; + } + catch (const std::exception& e) { + result.success = false; + result.error_msg = "{\"error\":\"" + std::string(e.what()) + "\"}"; + result.status_code = 500; + } + + return result; + } + + struct whisper_context* ctx_; + std::vector workers_; + std::queue> tasks_; + std::mutex queue_mutex_; + std::condition_variable queue_cv_; + std::mutex whisper_mutex_; // Protect whisper context access + std::atomic stop_flag_; +}; + int main(int argc, char ** argv) { ggml_backend_load_all(); whisper_params params; server_params sparams; - std::mutex whisper_mutex; - if (whisper_params_parse(argc, argv, params, sparams) == false) { whisper_print_usage(argc, argv, params, sparams); return 1; @@ -695,6 +1137,9 @@ int main(int argc, char ** argv) { whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); state.store(SERVER_STATE_READY); + // Create task queue with configurable worker threads + 
std::unique_ptr task_queue = std::make_unique(ctx, sparams.workers); + svr->set_default_headers({{"Server", "whisper.cpp"}, {"Access-Control-Allow-Origin", "*"}, @@ -784,9 +1229,6 @@ int main(int argc, char ** argv) { }); svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){ - // acquire whisper model mutex lock - std::lock_guard lock(whisper_mutex); - // first check user requested fields of the request if (!req.has_file("file")) { @@ -797,8 +1239,11 @@ int main(int argc, char ** argv) { } auto audio_file = req.get_file_value("file"); + // Create local copy of params for this request + whisper_params request_params = default_params; + // check non-required fields - get_req_parameters(req, params); + get_req_parameters(req, request_params); std::string filename{audio_file.filename}; printf("Received request: %s\n", filename.c_str()); @@ -823,7 +1268,7 @@ int main(int argc, char ** argv) { } // read audio content into pcmf32 - if (!::read_audio_data(temp_filename, pcmf32, pcmf32s, params.diarize)) + if (!::read_audio_data(temp_filename, pcmf32, pcmf32s, request_params.diarize)) { fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; @@ -834,7 +1279,7 @@ int main(int argc, char ** argv) { // remove temp file std::remove(temp_filename.c_str()); } else { - if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, params.diarize)) + if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, request_params.diarize)) { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; @@ -845,262 +1290,38 @@ int main(int argc, char ** argv) { printf("Successfully loaded %s\n", filename.c_str()); - // print system information - { - fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads*params.n_processors, 
std::thread::hardware_concurrency(), whisper_print_system_info()); - } - - // print some info about the processing - { - fprintf(stderr, "\n"); - if (!whisper_is_multilingual(ctx)) { - if (params.language != "en" || params.translate) { - params.language = "en"; - params.translate = false; - fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); - } - } - if (params.detect_language) { - params.language = "auto"; + // Handle multilingual models + if (!whisper_is_multilingual(ctx)) { + if (request_params.language != "en" || request_params.translate) { + request_params.language = "en"; + request_params.translate = false; + fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", - __func__, filename.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, - params.n_threads, params.n_processors, - params.language.c_str(), - params.translate ? "translate" : "transcribe", - params.tinydiarize ? "tdrz = 1, " : "", - params.no_timestamps ? 0 : 1); - - fprintf(stderr, "\n"); } - - // run the inference - { - printf("Running whisper.cpp inference on %s\n", filename.c_str()); - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - - wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; - - wparams.print_realtime = false; - wparams.print_progress = params.print_progress; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special = params.print_special; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.detect_language = params.detect_language; - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? 
params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.thold_pt = params.word_thold; - wparams.max_len = params.max_len == 0 ? 60 : params.max_len; - wparams.split_on_word = params.split_on_word; - wparams.audio_ctx = params.audio_ctx; - - wparams.debug_mode = params.debug_mode; - - wparams.tdrz_enable = params.tinydiarize; // [TDRZ] - - wparams.initial_prompt = params.prompt.c_str(); - - wparams.greedy.best_of = params.best_of; - wparams.beam_search.beam_size = params.beam_size; - - wparams.temperature = params.temperature; - wparams.no_speech_thold = params.no_speech_thold; - wparams.temperature_inc = params.temperature_inc; - wparams.entropy_thold = params.entropy_thold; - wparams.logprob_thold = params.logprob_thold; - - wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; - wparams.no_context = params.no_context; - - wparams.suppress_nst = params.suppress_nst; - - wparams.vad = params.vad; - wparams.vad_model_path = params.vad_model.c_str(); - - wparams.vad_params.threshold = params.vad_threshold; - wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms; - wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms; - wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s; - wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms; - wparams.vad_params.samples_overlap = params.vad_samples_overlap; - - whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; - - // this callback is called on each new segment - if (params.print_realtime) { - wparams.new_segment_callback = whisper_print_segment_callback; - wparams.new_segment_callback_user_data = &user_data; - } - - if (wparams.print_progress) { - wparams.progress_callback = whisper_print_progress_callback; - wparams.progress_callback_user_data = &user_data; - } - - // tell 
whisper to abort if the HTTP connection closed - wparams.abort_callback = [](void *user_data) { - // user_data is a pointer to our Request - auto req_ptr = static_cast(user_data); - return req_ptr->is_connection_closed(); - }; - wparams.abort_callback_user_data = (void*)&req; - - if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { - // handle failure or early abort - if (req.is_connection_closed()) { - // log client disconnect - fprintf(stderr, "client disconnected, aborted processing\n"); - res.status = 499; // Client Closed Request (nginx convention) - res.set_content("{\"error\":\"client disconnected\"}", "application/json"); - return; - } - fprintf(stderr, "%s: failed to process audio\n", argv[0]); - res.status = 500; // Internal Server Error - const std::string error_resp = "{\"error\":\"failed to process audio\"}"; - res.set_content(error_resp, "application/json"); - return; - } - } - - // return results to user - if (params.response_format == text_format) - { - std::string results = output_str(ctx, params, pcmf32s); - res.set_content(results.c_str(), "text/html; charset=utf-8"); + if (request_params.detect_language) { + request_params.language = "auto"; } - else if (params.response_format == srt_format) - { - std::stringstream ss; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } - - ss << i + 1 + params.offset_n << "\n"; - ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; - ss << speaker << text << "\n\n"; - } - res.set_content(ss.str(), "application/x-subrip"); - } else if (params.response_format == vtt_format) { - 
std::stringstream ss; - - ss << "WEBVTT\n\n"; - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); - speaker.insert(0, ""); - } - ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - ss << speaker << text << "\n\n"; - } - res.set_content(ss.str(), "text/vtt"); - } else if (params.response_format == vjson_format) { - /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); - // Get language probabilities - std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); - const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data()); - json jres = json{ - {"task", params.translate ? 
"translate" : "transcribe"}, - {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, - {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE}, - {"text", results}, - {"segments", json::array()}, - {"detected_language", whisper_lang_str_full(detected_lang_id)}, - {"detected_language_probability", lang_probs[detected_lang_id]}, - {"language_probabilities", json::object()} - }; - // Add all language probabilities - for (int i = 0; i <= whisper_lang_max_id(); ++i) { - if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities - jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i]; - } - } - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - json segment = json{ - {"id", i}, - {"text", whisper_full_get_segment_text(ctx, i)}, - }; - - if (!params.no_timestamps) { - segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; - segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; - } - - float total_logprob = 0; - const int n_tokens = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n_tokens; ++j) { - whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - if (token.id >= whisper_token_eot(ctx)) { - continue; - } - - segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; - if (!params.no_timestamps) { - word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; - word["t_dtw"] = token.t_dtw; - } - word["probability"] = token.p; - total_logprob += token.plog; - segment["words"].push_back(word); - } - - segment["temperature"] = params.temperature; - segment["avg_logprob"] = total_logprob / n_tokens; - - // TODO compression_ratio and no_speech_prob are not implemented yet - // segment["compression_ratio"] = 0; - segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i); - - jres["segments"].push_back(segment); - } - res.set_content(jres.dump(-1, ' ', false, 
json::error_handler_t::replace), - "application/json"); - } - // TODO add more output formats - else - { - std::string results = output_str(ctx, params, pcmf32s); - json jres = json{ - {"text", results} - }; - res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), - "application/json"); + // Create task and submit to queue + WhisperTask task; + task.pcmf32 = std::move(pcmf32); + task.pcmf32s = std::move(pcmf32s); + task.params = request_params; + task.filename = filename; + task.request_ptr = &req; // For abort callback + + // Execute task (blocks until completion) + WhisperTaskResult result = task_queue->executeTask(std::move(task)); + + // Set response + res.status = result.status_code; + if (result.success) { + res.set_content(result.output, result.content_type); + } else { + res.set_content(result.error_msg, "application/json"); } - - // reset params to their defaults - params = default_params; }); svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ - std::lock_guard lock(whisper_mutex); state.store(SERVER_STATE_LOADING_MODEL); if (!req.has_file("model")) { @@ -1118,6 +1339,9 @@ int main(int argc, char ** argv) { return; } + // shutdown task queue + task_queue.reset(); + // clean up whisper_free(ctx); @@ -1133,6 +1357,9 @@ int main(int argc, char ** argv) { // initialize openvino encoder. 
this has no effect on whisper.cpp builds that don't have OpenVINO configured whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + // recreate task queue with new context + task_queue = std::make_unique<WhisperTaskQueue>(ctx, sparams.workers); + state.store(SERVER_STATE_READY); const std::string success = "Load was successful!"; res.set_content(success, "application/text"); @@ -1193,6 +1420,7 @@ int main(int argc, char ** argv) { shutdown_handler = [&](int signal) { printf("\nCaught signal %d, shutting down gracefully...\n", signal); + if (task_queue) { task_queue->shutdown(); } // NOTE(review): /load and clean_up reset task_queue; a signal in that window would deref null without this guard if (svr) { svr->stop(); } @@ -1214,6 +1442,7 @@ int main(int argc, char ** argv) { // clean up function, to be called before exit auto clean_up = [&]() { + task_queue.reset(); whisper_print_timings(ctx); whisper_free(ctx); };