88#include " whisper.h"
99
1010#include < cassert>
11+ #include < chrono>
1112#include < cstdio>
13+ #include < fstream>
1214#include < string>
1315#include < thread>
1416#include < vector>
15- # include < fstream >
17+
1618#include < signal.h>
1719#include < sndfile.h>
1820
@@ -36,14 +38,14 @@ struct whisper_params {
3638 double keep_s = 0.0 ;
3739 int32_t capture_id = 10 ;
3840 int32_t max_tokens = 32 ;
39- int32_t audio_ctx = 768 ;
41+ int32_t audio_ctx = 0 ;
42+ int32_t beam_size = -1 ;
4043
41- float vad_thold = 0 .5f ;
42- float freq_thold = 200 .0f ;
44+ float vad_thold = 0 .6f ;
45+ float freq_thold = 100 .0f ;
4346
44- bool speed_up = false ;
4547 bool translate = false ;
46- bool no_fallback = true ;
48+ bool no_fallback = false ;
4749 bool print_special = false ;
4850 bool no_context = true ;
4951 bool no_timestamps = true ;
@@ -75,9 +77,9 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
7577 else if (arg == " -c" || arg == " --capture" ) { params.capture_id = std::stoi (argv[++i]); }
7678 else if (arg == " -mt" || arg == " --max-tokens" ) { params.max_tokens = std::stoi (argv[++i]); }
7779 else if (arg == " -ac" || arg == " --audio-ctx" ) { params.audio_ctx = std::stoi (argv[++i]); }
80+ else if (arg == " -bs" || arg == " --beam-size" ) { params.beam_size = std::stoi (argv[++i]); }
7881 else if (arg == " -vth" || arg == " --vad-thold" ) { params.vad_thold = std::stof (argv[++i]); }
7982 else if (arg == " -fth" || arg == " --freq-thold" ) { params.freq_thold = std::stof (argv[++i]); }
80- else if (arg == " -su" || arg == " --speed-up" ) { params.speed_up = true ; }
8183 else if (arg == " -tr" || arg == " --translate" ) { params.translate = true ; }
8284 else if (arg == " -nf" || arg == " --no-fallback" ) { params.no_fallback = true ; }
8385 else if (arg == " -ps" || arg == " --print-special" ) { params.print_special = true ; }
@@ -114,9 +116,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
114116 fprintf (stderr, " -c ID, --capture ID [%-7d] capture device ID\n " , params.capture_id );
115117 fprintf (stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n " , params.max_tokens );
116118 fprintf (stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n " , params.audio_ctx );
119+ fprintf (stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n " , params.beam_size );
117120 fprintf (stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n " , params.vad_thold );
118121 fprintf (stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n " , params.freq_thold );
119- fprintf (stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n " , params.speed_up ? " true" : " false" );
120122 fprintf (stderr, " -tr, --translate [%-7s] translate from source language to english\n " , params.translate ? " true" : " false" );
121123 fprintf (stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n " , params.no_fallback ? " true" : " false" );
122124 fprintf (stderr, " -ps, --print-special [%-7s] print special tokens\n " , params.print_special ? " true" : " false" );
@@ -133,12 +135,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
133135}
134136
135137int main (int argc, char ** argv) {
138+ ggml_backend_load_all ();
136139
137140 signal (SIGINT, exit_handler);
138141
139142 whisper_params params;
140143
141- if (whisper_params_parse (argc, argv, params) == false ) {
144+ if (whisper_params_parse (argc, argv, params) == false ) {
142145 return 1 ;
143146 }
144147
@@ -149,6 +152,7 @@ int main(int argc, char ** argv) {
149152 const int n_samples_len = (int )((params.length_s )*WHISPER_SAMPLE_RATE);
150153 const int n_samples_keep = (int )((params.keep_s )*WHISPER_SAMPLE_RATE);
151154 const int n_samples_30s = (int )((30.0 )*WHISPER_SAMPLE_RATE);
155+
152156 const int vadleast_n_samples_len = n_samples_len;
153157
154158 const bool use_vad = n_samples_step <= 0 ; // sliding window mode uses VAD
@@ -317,7 +321,7 @@ int main(int argc, char ** argv) {
317321
318322 // run the inference
319323 {
320- whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
324+ whisper_full_params wparams = whisper_full_default_params (params. beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
321325
322326 wparams.print_progress = false ;
323327 wparams.print_special = params.print_special ;
@@ -328,15 +332,15 @@ int main(int argc, char ** argv) {
328332 wparams.max_tokens = params.max_tokens ;
329333 wparams.language = params.language .c_str ();
330334 wparams.n_threads = params.n_threads ;
335+ wparams.beam_search .beam_size = params.beam_size ;
331336
332337 wparams.audio_ctx = params.audio_ctx ;
333- wparams.speed_up = params.speed_up ;
334338
335339 wparams.tdrz_enable = params.tinydiarize ; // [TDRZ]
336340
337341 // disable temperature fallback
338342 wparams.temperature_inc = -1 .0f ;
339- // wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
343+ wparams.temperature_inc = params.no_fallback ? 0 .0f : wparams.temperature_inc ;
340344
341345 wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data ();
342346 wparams.prompt_n_tokens = params.no_context ? 0 : (int )prompt_tokens.size ();
@@ -397,6 +401,7 @@ int main(int argc, char ** argv) {
397401
398402 std::cout << output;
399403 fflush (stdout);
404+
400405 if (params.fname_out .length () > 0 ) {
401406 fout << output;
402407 }
@@ -448,7 +453,5 @@ int main(int argc, char ** argv) {
448453 whisper_print_timings (ctx);
449454 whisper_free (ctx);
450455
451-
452-
453456 return 0 ;
454457}
0 commit comments