Skip to content

Commit ada8c05

Browse files
committed
update code.
1 parent 2c46fe1 commit ada8c05

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ if (WHISPER_portaudio)
218218
endif (WHISPER_portaudio)
219219
#add_subdirectory(server)
220220
#add_subdirectory(quantize)
221-
#add_subdirectory(vad-speech-segments)
221+
add_subdirectory(vad-speech-segments)
222222
if (WHISPER_SDL2)
223223
#add_subdirectory(stream)
224224
#add_subdirectory(command)

examples/stream-portaudio/stream-portaudio.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
#include "whisper.h"
99

1010
#include <cassert>
11+
#include <chrono>
1112
#include <cstdio>
13+
#include <fstream>
1214
#include <string>
1315
#include <thread>
1416
#include <vector>
15-
#include <fstream>
17+
1618
#include <signal.h>
1719
#include <sndfile.h>
1820

@@ -36,14 +38,14 @@ struct whisper_params {
3638
double keep_s = 0.0;
3739
int32_t capture_id = 10;
3840
int32_t max_tokens = 32;
39-
int32_t audio_ctx = 768;
41+
int32_t audio_ctx = 0;
42+
int32_t beam_size = -1;
4043

41-
float vad_thold = 0.5f;
42-
float freq_thold = 200.0f;
44+
float vad_thold = 0.6f;
45+
float freq_thold = 100.0f;
4346

44-
bool speed_up = false;
4547
bool translate = false;
46-
bool no_fallback = true;
48+
bool no_fallback = false;
4749
bool print_special = false;
4850
bool no_context = true;
4951
bool no_timestamps = true;
@@ -75,9 +77,9 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
7577
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
7678
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
7779
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
80+
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
7881
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
7982
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
80-
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
8183
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
8284
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
8385
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
@@ -114,9 +116,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
114116
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
115117
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
116118
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
119+
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
117120
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
118121
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
119-
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
120122
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
121123
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
122124
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
@@ -133,12 +135,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
133135
}
134136

135137
int main(int argc, char ** argv) {
138+
ggml_backend_load_all();
136139

137140
signal(SIGINT, exit_handler);
138141

139142
whisper_params params;
140143

141-
if (whisper_params_parse(argc, argv, params) == false) {
144+
if (whisper_params_parse(argc, argv, params) == false) {
142145
return 1;
143146
}
144147

@@ -149,6 +152,7 @@ int main(int argc, char ** argv) {
149152
const int n_samples_len = (int)((params.length_s)*WHISPER_SAMPLE_RATE);
150153
const int n_samples_keep = (int)((params.keep_s )*WHISPER_SAMPLE_RATE);
151154
const int n_samples_30s = (int)((30.0 )*WHISPER_SAMPLE_RATE);
155+
152156
const int vadleast_n_samples_len = n_samples_len;
153157

154158
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
@@ -317,7 +321,7 @@ int main(int argc, char ** argv) {
317321

318322
// run the inference
319323
{
320-
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
324+
whisper_full_params wparams = whisper_full_default_params(params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY);
321325

322326
wparams.print_progress = false;
323327
wparams.print_special = params.print_special;
@@ -328,15 +332,15 @@ int main(int argc, char ** argv) {
328332
wparams.max_tokens = params.max_tokens;
329333
wparams.language = params.language.c_str();
330334
wparams.n_threads = params.n_threads;
335+
wparams.beam_search.beam_size = params.beam_size;
331336

332337
wparams.audio_ctx = params.audio_ctx;
333-
wparams.speed_up = params.speed_up;
334338

335339
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
336340

337341
// disable temperature fallback
338342
wparams.temperature_inc = -1.0f;
339-
//wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
343+
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
340344

341345
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
342346
wparams.prompt_n_tokens = params.no_context ? 0 : (int)prompt_tokens.size();
@@ -397,6 +401,7 @@ int main(int argc, char ** argv) {
397401

398402
std::cout << output;
399403
fflush(stdout);
404+
400405
if (params.fname_out.length() > 0) {
401406
fout << output;
402407
}
@@ -448,7 +453,5 @@ int main(int argc, char ** argv) {
448453
whisper_print_timings(ctx);
449454
whisper_free(ctx);
450455

451-
452-
453456
return 0;
454457
}

0 commit comments

Comments
 (0)