8
8
#include " whisper.h"
9
9
10
10
#include < cassert>
11
+ #include < chrono>
11
12
#include < cstdio>
13
+ #include < fstream>
12
14
#include < string>
13
15
#include < thread>
14
16
#include < vector>
15
- # include < fstream >
17
+
16
18
#include < signal.h>
17
19
#include < sndfile.h>
18
20
@@ -36,14 +38,14 @@ struct whisper_params {
36
38
double keep_s = 0.0 ;
37
39
int32_t capture_id = 10 ;
38
40
int32_t max_tokens = 32 ;
39
- int32_t audio_ctx = 768 ;
41
+ int32_t audio_ctx = 0 ;
42
+ int32_t beam_size = -1 ;
40
43
41
- float vad_thold = 0 .5f ;
42
- float freq_thold = 200 .0f ;
44
+ float vad_thold = 0 .6f ;
45
+ float freq_thold = 100 .0f ;
43
46
44
- bool speed_up = false ;
45
47
bool translate = false ;
46
- bool no_fallback = true ;
48
+ bool no_fallback = false ;
47
49
bool print_special = false ;
48
50
bool no_context = true ;
49
51
bool no_timestamps = true ;
@@ -75,9 +77,9 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
75
77
else if (arg == " -c" || arg == " --capture" ) { params.capture_id = std::stoi (argv[++i]); }
76
78
else if (arg == " -mt" || arg == " --max-tokens" ) { params.max_tokens = std::stoi (argv[++i]); }
77
79
else if (arg == " -ac" || arg == " --audio-ctx" ) { params.audio_ctx = std::stoi (argv[++i]); }
80
+ else if (arg == " -bs" || arg == " --beam-size" ) { params.beam_size = std::stoi (argv[++i]); }
78
81
else if (arg == " -vth" || arg == " --vad-thold" ) { params.vad_thold = std::stof (argv[++i]); }
79
82
else if (arg == " -fth" || arg == " --freq-thold" ) { params.freq_thold = std::stof (argv[++i]); }
80
- else if (arg == " -su" || arg == " --speed-up" ) { params.speed_up = true ; }
81
83
else if (arg == " -tr" || arg == " --translate" ) { params.translate = true ; }
82
84
else if (arg == " -nf" || arg == " --no-fallback" ) { params.no_fallback = true ; }
83
85
else if (arg == " -ps" || arg == " --print-special" ) { params.print_special = true ; }
@@ -114,9 +116,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
114
116
fprintf (stderr, " -c ID, --capture ID [%-7d] capture device ID\n " , params.capture_id );
115
117
fprintf (stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n " , params.max_tokens );
116
118
fprintf (stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n " , params.audio_ctx );
119
+ fprintf (stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n " , params.beam_size );
117
120
fprintf (stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n " , params.vad_thold );
118
121
fprintf (stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n " , params.freq_thold );
119
- fprintf (stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n " , params.speed_up ? " true" : " false" );
120
122
fprintf (stderr, " -tr, --translate [%-7s] translate from source language to english\n " , params.translate ? " true" : " false" );
121
123
fprintf (stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n " , params.no_fallback ? " true" : " false" );
122
124
fprintf (stderr, " -ps, --print-special [%-7s] print special tokens\n " , params.print_special ? " true" : " false" );
@@ -133,12 +135,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
133
135
}
134
136
135
137
int main (int argc, char ** argv) {
138
+ ggml_backend_load_all ();
136
139
137
140
signal (SIGINT, exit_handler);
138
141
139
142
whisper_params params;
140
143
141
- if (whisper_params_parse (argc, argv, params) == false ) {
144
+ if (whisper_params_parse (argc, argv, params) == false ) {
142
145
return 1 ;
143
146
}
144
147
@@ -149,6 +152,7 @@ int main(int argc, char ** argv) {
149
152
const int n_samples_len = (int )((params.length_s )*WHISPER_SAMPLE_RATE);
150
153
const int n_samples_keep = (int )((params.keep_s )*WHISPER_SAMPLE_RATE);
151
154
const int n_samples_30s = (int )((30.0 )*WHISPER_SAMPLE_RATE);
155
+
152
156
const int vadleast_n_samples_len = n_samples_len;
153
157
154
158
const bool use_vad = n_samples_step <= 0 ; // sliding window mode uses VAD
@@ -317,7 +321,7 @@ int main(int argc, char ** argv) {
317
321
318
322
// run the inference
319
323
{
320
- whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
324
+ whisper_full_params wparams = whisper_full_default_params (params. beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
321
325
322
326
wparams.print_progress = false ;
323
327
wparams.print_special = params.print_special ;
@@ -328,15 +332,15 @@ int main(int argc, char ** argv) {
328
332
wparams.max_tokens = params.max_tokens ;
329
333
wparams.language = params.language .c_str ();
330
334
wparams.n_threads = params.n_threads ;
335
+ wparams.beam_search .beam_size = params.beam_size ;
331
336
332
337
wparams.audio_ctx = params.audio_ctx ;
333
- wparams.speed_up = params.speed_up ;
334
338
335
339
wparams.tdrz_enable = params.tinydiarize ; // [TDRZ]
336
340
337
341
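// disable temperature fallback
// NOTE(review): temperature_inc is forced to -1.0f by the next statement, before the no_fallback ternary further down, so that ternary can only yield 0.0f or -1.0f and never the default from whisper_full_default_params(); confirm this is intended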
338
342
wparams.temperature_inc = -1 .0f ;
339
- // wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
343
+ wparams.temperature_inc = params.no_fallback ? 0 .0f : wparams.temperature_inc ;
340
344
341
345
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data ();
342
346
wparams.prompt_n_tokens = params.no_context ? 0 : (int )prompt_tokens.size ();
@@ -397,6 +401,7 @@ int main(int argc, char ** argv) {
397
401
398
402
std::cout << output;
399
403
fflush (stdout);
404
+
400
405
if (params.fname_out .length () > 0 ) {
401
406
fout << output;
402
407
}
@@ -448,7 +453,5 @@ int main(int argc, char ** argv) {
448
453
whisper_print_timings (ctx);
449
454
whisper_free (ctx);
450
455
451
-
452
-
453
456
return 0 ;
454
457
}
0 commit comments