Skip to content

Commit 289946d

Browse files
committed
Simplify stream's pcmf32 handling
Use one deque instead of two vectors (old and new). Old and new are now length variables. Basically: Get `step - new` samples every time. Then set `new = (around) step;` The new audio data is simply appended to the deque. (The deque size is limited to 30 seconds.) Pass `old + new` samples to whisper inference. If the data has been consumed, let `old = 0; new = 0;` If some of the data should be kept for the next iteration, `old = keep;` If you want to get only N samples next time, `new = step - N;` In VAD mode: `stream --interim --step -3000` gets 3000 ms of audio. Run `vad_simple(step_ms)`. If nothing is detected, get 100 ms more audio and retry. If nothing is detected and 3000 ms have passed, enter interim mode, where `n_segments - 1` segments will be confirmed. (`old -= confirmed_t1`) If `n_segments == 1`, only show the first half of the result. Misc: Increase the default `max_tokens` because 32 is too small for 10 seconds of audio. (Some Japanese speech was garbled.) Write the wav file as soon as the data is available. `no_timestamps` is the default even for VAD because the output is more useful to show to the hard-of-hearing that way.
1 parent 419aee3 commit 289946d

File tree

1 file changed

+146
-100
lines changed

1 file changed

+146
-100
lines changed

examples/stream/stream.cpp

Lines changed: 146 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "whisper.h"
88

99
#include <cassert>
10+
#include <codecvt>
1011
#include <cstdio>
1112
#include <string>
1213
#include <thread>
@@ -21,7 +22,7 @@ struct whisper_params {
2122
int32_t length_ms = 10000;
2223
int32_t keep_ms = 200;
2324
int32_t capture_id = -1;
24-
int32_t max_tokens = 32;
25+
int32_t max_tokens = 128;
2526
int32_t audio_ctx = 0;
2627

2728
float vad_thold = 0.6f;
@@ -36,6 +37,7 @@ struct whisper_params {
3637
bool save_audio = false; // save audio to wav file
3738
bool use_gpu = true;
3839
bool flash_attn = false;
40+
bool interim = false;
3941

4042
std::string language = "en";
4143
std::string model = "models/ggml-base.en.bin";
@@ -65,13 +67,15 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
6567
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
6668
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
6769
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
70+
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
6871
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
6972
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
7073
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
7174
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
7275
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
7376
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
7477
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
78+
else if (arg == "-int" || arg == "--interim") { params.interim = true; }
7579

7680
else {
7781
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -102,13 +106,15 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
102106
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
103107
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
104108
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
109+
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
105110
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
106111
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
107112
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
108113
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
109114
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
110115
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
111116
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
117+
fprintf(stderr, " -int, --interim [%-7s] show interim report in vad every step\n", params.interim ? "true" : "false");
112118
fprintf(stderr, "\n");
113119
}
114120

@@ -122,19 +128,16 @@ int main(int argc, char ** argv) {
122128
params.keep_ms = std::min(params.keep_ms, params.step_ms);
123129
params.length_ms = std::max(params.length_ms, params.step_ms);
124130

125-
const int n_samples_step = (1e-3*params.step_ms )*WHISPER_SAMPLE_RATE;
126-
const int n_samples_len = (1e-3*params.length_ms)*WHISPER_SAMPLE_RATE;
127-
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
128-
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
131+
const int n_samples_step = (1e-3*abs(params.step_ms))*WHISPER_SAMPLE_RATE;
132+
const int n_samples_len = (1e-3*params.length_ms )*WHISPER_SAMPLE_RATE;
133+
const int n_samples_keep = (1e-3*params.keep_ms )*WHISPER_SAMPLE_RATE;
134+
const int n_samples_30s = (1e-3*30000.0 )*WHISPER_SAMPLE_RATE;
135+
const int n_samples_100ms= (1e-3*100.0 )*WHISPER_SAMPLE_RATE;
129136

130-
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
137+
const bool use_vad = params.step_ms <= 0; // sliding window mode uses VAD
131138

132139
const int n_new_line = !use_vad ? std::max(1, params.length_ms / params.step_ms - 1) : 1; // number of steps to print new line
133140

134-
params.no_timestamps = !use_vad;
135-
params.no_context |= use_vad;
136-
params.max_tokens = 0;
137-
138141
// init audio
139142

140143
audio_async audio(params.length_ms);
@@ -159,9 +162,10 @@ int main(int argc, char ** argv) {
159162

160163
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
161164

162-
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
163-
std::vector<float> pcmf32_old;
164-
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
165+
std::vector<float> pcmf32(n_samples_30s, 0.0f);
166+
std::deque<float> pcmf32_deque;
167+
int n_samples_new = 0;
168+
int n_samples_old = 0;
165169

166170
std::vector<whisper_token> prompt_tokens;
167171

@@ -219,17 +223,17 @@ int main(int argc, char ** argv) {
219223

220224
wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1);
221225
}
222-
printf("[Start speaking]\n");
223-
fflush(stdout);
226+
fprintf(stderr, "[Start speaking]\n");
227+
fflush(stderr);
224228

225229
auto t_last = std::chrono::high_resolution_clock::now();
230+
auto t_interim = t_last;
231+
bool is_interim = false;
226232
const auto t_start = t_last;
233+
std::string s_to_delete = "";
227234

228235
// main audio loop
229236
while (is_running) {
230-
if (params.save_audio) {
231-
wavWriter.write(pcmf32_new.data(), pcmf32_new.size());
232-
}
233237
// handle Ctrl + C
234238
is_running = sdl_poll_events();
235239

@@ -238,61 +242,74 @@ int main(int argc, char ** argv) {
238242
}
239243

240244
// process new audio
245+
const auto t_now = std::chrono::high_resolution_clock::now();
246+
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
247+
248+
// get new audio
249+
if (n_samples_new > n_samples_step) {
250+
pcmf32.clear();
251+
} else if (t_diff < abs(params.step_ms)) {
252+
std::this_thread::sleep_for(std::chrono::milliseconds(abs(params.step_ms) - t_diff));
253+
continue;
254+
} else {
255+
audio.next(pcmf32);
256+
}
241257

242-
if (!use_vad) {
243-
while (true) {
244-
audio.next(pcmf32_new);
245-
246-
if ((int) pcmf32_new.size() > 2*n_samples_step) {
247-
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
248-
audio.clear();
249-
continue;
250-
}
251-
252-
if ((int) pcmf32_new.size() >= n_samples_step) {
253-
break;
254-
}
255-
256-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
257-
}
258-
259-
const int n_samples_new = pcmf32_new.size();
260-
261-
// take up to params.length_ms audio from previous iteration
262-
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
258+
const int n_samples_buf = pcmf32.size();
263259

264-
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
260+
if (params.save_audio && n_samples_buf > 0) {
261+
wavWriter.write(pcmf32.data(), n_samples_buf);
262+
}
265263

266-
pcmf32.resize(n_samples_new + n_samples_take);
264+
copy(pcmf32.begin(), pcmf32.end(), back_inserter(pcmf32_deque));
265+
if (pcmf32_deque.size() > n_samples_30s) {
266+
pcmf32_deque.erase(pcmf32_deque.begin(), pcmf32_deque.end() - n_samples_30s);
267+
}
267268

268-
for (int i = 0; i < n_samples_take; i++) {
269-
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
270-
}
269+
n_samples_new += n_samples_buf;
270+
if (!is_interim && n_samples_new > 2*n_samples_step) {
271+
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n", __func__);
272+
fprintf(stderr, "t_diff = %.2fs, new = %.2fs, buf = %.2fs\n\n", 1e-3*t_diff, float(n_samples_new)/WHISPER_SAMPLE_RATE, float(n_samples_buf)/WHISPER_SAMPLE_RATE);
273+
n_samples_old = 0;
274+
n_samples_new = 0;
275+
t_last = t_now;
276+
continue;
277+
}
278+
is_interim = false;
271279

272-
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
280+
if (!use_vad){
281+
n_samples_old += n_samples_new;
282+
n_samples_new = 0;
283+
pcmf32.resize(n_samples_old);
284+
copy(pcmf32_deque.end() - n_samples_old, pcmf32_deque.end(), pcmf32.begin());
273285

274-
pcmf32_old = pcmf32;
286+
t_last = t_now;
275287
} else {
276-
const auto t_now = std::chrono::high_resolution_clock::now();
277-
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
278-
279-
if (t_diff < 2000) {
280-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
281-
282-
continue;
283-
}
284-
285-
audio.get(2000, pcmf32_new);
286-
287-
if (::vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
288-
audio.get(params.length_ms, pcmf32);
288+
pcmf32.resize(n_samples_step);
289+
copy(pcmf32_deque.end() - n_samples_step, pcmf32_deque.end(), pcmf32.begin());
290+
if (::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, std::min(1000, abs(params.step_ms) / 2), params.vad_thold, params.freq_thold, false)) {
291+
pcmf32.resize(n_samples_old + n_samples_new);
292+
copy(pcmf32_deque.end() - n_samples_old - n_samples_new, pcmf32_deque.end(), pcmf32.begin());
293+
n_samples_new = 0;
294+
n_samples_old = 0;
295+
296+
t_last = t_now;
289297
} else {
290-
std::this_thread::sleep_for(std::chrono::milliseconds(100));
291-
292-
continue;
298+
const auto n_interim_diff_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_interim).count();
299+
300+
if (params.interim && n_interim_diff_ms > abs(params.step_ms)) {
301+
is_interim = (n_interim_diff_ms < params.length_ms - abs(params.step_ms));
302+
n_samples_old += n_samples_new;
303+
n_samples_new = 0;
304+
pcmf32.resize(n_samples_old);
305+
copy(pcmf32_deque.end() - n_samples_old, pcmf32_deque.end(), pcmf32.begin());
306+
} else {
307+
n_samples_new -= n_samples_100ms;
308+
n_samples_old = std::min(n_samples_len, n_samples_old + n_samples_100ms);
309+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
310+
continue;
311+
}
293312
}
294-
295-
t_last = t_now;
296313
}
297314

298315
// run the inference
@@ -324,80 +341,109 @@ int main(int argc, char ** argv) {
324341
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
325342
return 6;
326343
}
344+
t_interim = std::chrono::high_resolution_clock::now();
327345

328346
// print result;
347+
int n_segments;
348+
bool is_unconfirmed = false;
349+
std::ostringstream text;
329350
{
330-
if (!use_vad) {
351+
if (!use_vad || params.interim && params.no_timestamps && s_to_delete.size()) {
331352
printf("\33[2K\r");
332353

333354
// print long empty line to clear the previous line
334-
printf("%s", std::string(100, ' ').c_str());
355+
printf("%s", std::string(s_to_delete.size(), ' ').c_str());
335356

336357
printf("\33[2K\r");
337-
} else {
358+
} else if (use_vad && !params.no_timestamps) {
338359
const int64_t t1 = (t_last - t_start).count()/1000000;
339360
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
340361

341-
printf("\n");
342-
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
343-
printf("\n");
362+
text << std::endl;
363+
text << "### Transcription " << n_iter << " START | t0 = " << t0 << " ms | t1 = " << t1 << " ms" << std::endl;
364+
text << std::endl;
344365
}
345366

346-
const int n_segments = whisper_full_n_segments(ctx);
367+
n_segments = whisper_full_n_segments(ctx);
368+
if (is_interim) {
369+
if (n_segments < 2) {
370+
is_unconfirmed = true;
371+
} else {
372+
n_segments--;
373+
const int64_t t1_ms = whisper_full_get_segment_t1(ctx, n_segments - 1) * 10;
374+
t_last += std::chrono::milliseconds(t1_ms);
375+
const auto n_confirmed = (1e-3*t1_ms)*WHISPER_SAMPLE_RATE;
376+
pcmf32.resize(n_confirmed);
377+
n_samples_old -= n_confirmed;
378+
}
379+
}
347380
for (int i = 0; i < n_segments; ++i) {
348-
const char * text = whisper_full_get_segment_text(ctx, i);
349-
350-
if (params.no_timestamps) {
351-
printf("%s", text);
352-
fflush(stdout);
381+
std::string i_text = whisper_full_get_segment_text(ctx, i);
353382

354-
if (params.fname_out.length() > 0) {
355-
fout << text;
383+
if (!use_vad || params.no_timestamps) {
384+
if (i > 0) {
385+
text << std::endl;
356386
}
387+
text << i_text;
357388
} else {
358-
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
359-
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
389+
const int64_t t_end = (t_last - t_start).count()/1000000;
390+
const int64_t t_beg = std::max(0.0, t_end - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
391+
const int64_t t0 = t_beg/10 + whisper_full_get_segment_t0(ctx, i);
392+
const int64_t t1 = t_beg/10 + whisper_full_get_segment_t1(ctx, i);
360393

361-
std::string output = "[" + to_timestamp(t0, false) + " --> " + to_timestamp(t1, false) + "] " + text;
394+
text << "[" << to_timestamp(t0, false) << " --> " << to_timestamp(t1, false) << "] " << i_text;
362395

363396
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
364-
output += " [SPEAKER_TURN]";
397+
text << " [SPEAKER_TURN]";
365398
}
366399

367-
output += "\n";
368-
369-
printf("%s", output.c_str());
370-
fflush(stdout);
371-
372-
if (params.fname_out.length() > 0) {
373-
fout << output;
374-
}
400+
text << std::endl;
375401
}
376402
}
377403

378-
if (params.fname_out.length() > 0) {
379-
fout << std::endl;
404+
if (use_vad && !params.no_timestamps) {
405+
text << std::endl;
406+
text << "### Transcription " << n_iter << " END";
407+
text << std::endl;
380408
}
409+
}
381410

382-
if (use_vad) {
383-
printf("\n");
384-
printf("### Transcription %d END\n", n_iter);
385-
}
411+
if (params.fname_out.length() > 0) {
412+
fout << text.str();
413+
fout << std::endl;
386414
}
387415

388416
++n_iter;
389417

390-
if (!use_vad && (n_iter % n_new_line) == 0) {
418+
if (is_unconfirmed) {
419+
--n_iter;
420+
// utf-8 cannot be simply cut into two
421+
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
422+
auto t_u32 = conv.from_bytes(text.str());
423+
auto t_sub = conv.to_bytes(t_u32.substr(0, t_u32.size() / 2));
424+
text.str(t_sub + "");
425+
}
426+
427+
printf("%s", text.str().c_str());
428+
429+
if (is_unconfirmed || !use_vad && n_samples_old < n_samples_len - n_samples_step) {
430+
s_to_delete = text.str();
431+
} else {
391432
printf("\n");
433+
s_to_delete = "";
392434

393-
// keep part of the audio for next iteration to try to mitigate word boundary issues
394-
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
435+
if (!use_vad) {
436+
n_iter = 0;
437+
if (n_samples_keep < n_samples_old) {
438+
// keep part of the audio for next iteration to try to mitigate word boundary issues
439+
n_samples_old = n_samples_keep;
440+
}
441+
}
395442

396443
// Add tokens of the last full length segment as the prompt
397444
if (!params.no_context) {
398445
prompt_tokens.clear();
399446

400-
const int n_segments = whisper_full_n_segments(ctx);
401447
for (int i = 0; i < n_segments; ++i) {
402448
const int token_count = whisper_full_n_tokens(ctx, i);
403449
for (int j = 0; j < token_count; ++j) {

0 commit comments

Comments
 (0)