diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 37b23886821..6d496f39e91 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include // command-line parameters struct whisper_params { @@ -37,6 +39,7 @@ struct whisper_params { bool save_audio = false; // save audio to wav file bool use_gpu = true; bool flash_attn = false; + bool pausable = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -74,6 +77,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-p" || arg == "--pausable") { params.pausable = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -112,6 +116,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -p, --pausable [%-7s] allow stdin commands p,r (PAUSE)/(RESUME)\n", params.pausable ? "true" : "false"); fprintf(stderr, "\n"); } @@ -206,7 +211,9 @@ int main(int argc, char ** argv) { int n_iter = 0; - bool is_running = true; + std::atomic_bool is_running(true); + std::atomic_bool is_paused(false); + std::atomic_int control_state(0); // 1 - pause, 2 - resume std::ofstream fout; if (params.fname_out.length() > 0) { @@ -231,6 +238,26 @@ int main(int argc, char ** argv) { printf("[Start speaking]\n"); fflush(stdout); + std::thread control_thread; + if (params.pausable) { + control_thread = std::thread([&]() { + std::string line; + while (is_running) { + if (!std::getline(std::cin, line)) { + break; + } + + if (line == "p") { + control_state = 1; + } else if (line == "r") { + control_state = 2; + } else { + fprintf(stderr, "[ERROR] Only 'p' (pause), 'r' (resume) accepted]\n"); + } + } + }); + } + auto t_last = std::chrono::high_resolution_clock::now(); const auto t_start = t_last; @@ -246,6 +273,35 @@ int main(int argc, char ** argv) { break; } + if (params.pausable) { + int st = control_state.exchange(0); + if (st == 1 && !is_paused) { + audio.clear(); + audio.pause(); + + params.no_context = true; + + pcmf32.clear(); + pcmf32_new.clear(); + pcmf32_old.clear(); + prompt_tokens.clear(); + is_paused = true; + whisper_reset_timings(ctx); + } else if (st == 2 && is_paused) { + audio.resume(); + audio.clear(); + whisper_reset_timings(ctx); + is_paused = false; + t_last = std::chrono::high_resolution_clock::now(); + } + + if (is_paused) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + t_last = std::chrono::high_resolution_clock::now(); + continue; + } + } + // process new audio if (!use_vad) { @@ -431,5 +487,8 @@ int main(int argc, char ** argv) { whisper_print_timings(ctx); whisper_free(ctx); + if (control_thread.joinable()) + control_thread.join(); + return 0; }