Skip to content

Commit 0a84581

Browse files
committed
Make stream more test-friendly
Run `stream --test-pipe --no-vt100 2>/dev/null < pcmf32.raw` to get nearly-reproducible results. If you want to do a strict testing, use `--no-timestamps` as well. ``` cat jfk.raw | ./build/bin/stream -m models/ggml-large-v2.bin --step 2000 --test-pipe -no-vt100 2>/dev/null ( And so my fellow Americans...) ( And so my fellow Americans, ask...) ( And so my fellow Americans, ask not what your country will give you, but what your country will give you.) [00:00:00.000 --> 00:00:30.000] And so my fellow Americans, ask not what your country can do for you. ( Ask what you can do for your) [00:00:02.360 --> 00:00:32.360] Ask what you can do for your country. ``` VAD: ``` cat jfk.raw | ./build/bin/stream -m models/ggml-large-v2.bin --step -2000 --test-pipe -no-vt100 2>/dev/null [00:00:00.000 --> 00:00:03.000] And so, my fellow Americans. [00:00:00.000 --> 00:00:07.920] Ask not what your country can do for you, ask what you can do for your country. ```
1 parent 425d3ad commit 0a84581

File tree

1 file changed

+107
-59
lines changed

1 file changed

+107
-59
lines changed

examples/stream/stream.cpp

Lines changed: 107 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ struct whisper_params {
7373
bool use_gpu = true;
7474
bool flash_attn = false;
7575
bool interim = false;
76+
bool delete_vt100 = true;
77+
bool test_pipe = false;
7678

7779
std::string language = "en";
7880
std::string model = "models/ggml-base.en.bin";
@@ -111,6 +113,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
111113
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
112114
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
113115
else if (arg == "-int" || arg == "--interim") { params.interim = true; }
116+
else if (arg == "-nvt" || arg == "--no-vt100") { params.delete_vt100 = false; }
117+
else if ( arg == "--test-pipe") { params.test_pipe = true; }
114118

115119
else {
116120
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -150,6 +154,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
150154
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
151155
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
152156
fprintf(stderr, " -int, --interim [%-7s] show interim report in vad every step\n", params.interim ? "true" : "false");
157+
fprintf(stderr, " -nvt, --no-vt100 [%-7s] do not delete unconfirmed result\n", params.delete_vt100 ? "false" : "true");
158+
fprintf(stderr, " --test-pipe [%-7s] use all data from pipe\n", params.test_pipe ? "true" : "false");
153159
fprintf(stderr, "\n");
154160
}
155161

@@ -160,8 +166,8 @@ int main(int argc, char ** argv) {
160166
return 1;
161167
}
162168

163-
params.keep_ms = std::min(params.keep_ms, params.step_ms);
164-
params.length_ms = std::max(params.length_ms, params.step_ms);
169+
params.keep_ms = std::min(params.keep_ms, abs(params.step_ms));
170+
params.length_ms = std::max(params.length_ms, abs(params.step_ms));
165171

166172
const int n_samples_step = (1e-3*abs(params.step_ms))*WHISPER_SAMPLE_RATE;
167173
const int n_samples_len = (1e-3*params.length_ms )*WHISPER_SAMPLE_RATE;
@@ -269,7 +275,7 @@ int main(int argc, char ** argv) {
269275

270276
// ignore premature stdin
271277
int n_mod = 0;
272-
if (piped) {
278+
if (piped && !params.test_pipe) {
273279
const auto n_bytes_len = sizeof(float) * n_samples_len;
274280
setStdinNonBlocking();
275281
while (true) {
@@ -349,9 +355,6 @@ int main(int argc, char ** argv) {
349355
}
350356
}
351357
pcmf32.resize(n_bytes_read / sizeof(float));
352-
if (!is_running) {
353-
break;
354-
}
355358
} else if (t_diff < abs(params.step_ms)) {
356359
std::this_thread::sleep_for(std::chrono::milliseconds(abs(params.step_ms) - t_diff));
357360
continue;
@@ -371,15 +374,21 @@ int main(int argc, char ** argv) {
371374
}
372375

373376
n_samples_new += n_samples_buf;
374-
if (!is_interim && n_samples_new > 2*n_samples_step) {
377+
if (!use_vad && n_samples_new > 2*n_samples_step) {
375378
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n", __func__);
376379
fprintf(stderr, "t_diff = %.2fs, new = %.2fs, buf = %.2fs\n\n", 1e-3*t_diff, float(n_samples_new)/WHISPER_SAMPLE_RATE, float(n_samples_buf)/WHISPER_SAMPLE_RATE);
377380
n_samples_old = 0;
378381
n_samples_new = 0;
379382
t_last = t_now;
380383
continue;
381384
}
385+
386+
if (n_samples_old + n_samples_new == 0) {
387+
continue;
388+
}
389+
382390
is_interim = false;
391+
bool is_aborted = true;
383392

384393
if (!use_vad){
385394
n_samples_old += n_samples_new;
@@ -389,11 +398,17 @@ int main(int argc, char ** argv) {
389398

390399
t_last = t_now;
391400
} else {
392-
pcmf32.resize(n_samples_step);
393-
copy(pcmf32_deque.end() - n_samples_step, pcmf32_deque.end(), pcmf32.begin());
394-
if (::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, std::min(1000, abs(params.step_ms) / 2), params.vad_thold, params.freq_thold, false)) {
395-
pcmf32.resize(n_samples_old + n_samples_new);
396-
copy(pcmf32_deque.end() - n_samples_old - n_samples_new, pcmf32_deque.end(), pcmf32.begin());
401+
const auto n_samples = std::min(n_samples_len, n_samples_old + n_samples_new);
402+
403+
is_aborted = (n_samples > n_samples_len);
404+
if (is_running && !is_aborted) {
405+
pcmf32.resize(n_samples_step);
406+
copy(pcmf32_deque.end() - n_samples_step, pcmf32_deque.end(), pcmf32.begin());
407+
}
408+
409+
if (!is_running || is_aborted || ::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, std::min(1000, abs(params.step_ms) / 2), params.vad_thold, params.freq_thold, false)) {
410+
pcmf32.resize(n_samples);
411+
copy(pcmf32_deque.end() - n_samples, pcmf32_deque.end(), pcmf32.begin());
397412
n_samples_new = 0;
398413
n_samples_old = 0;
399414

@@ -443,25 +458,50 @@ int main(int argc, char ** argv) {
443458
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
444459
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
445460

446-
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
447-
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
448-
return 6;
461+
{
462+
auto pcm_size = pcmf32.size();
463+
if (pcm_size < WHISPER_SAMPLE_RATE * 1.1) {
464+
pcmf32.resize(pcm_size + WHISPER_SAMPLE_RATE, 0.0f);
465+
}
466+
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
467+
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
468+
return 6;
469+
}
470+
pcmf32.resize(pcm_size);
449471
}
450472
t_interim = std::chrono::high_resolution_clock::now();
451473

452474
// print result;
453475
int n_segments;
454-
bool is_unconfirmed = false;
476+
bool no_confirmed = (!use_vad && n_samples_old < n_samples_len - n_samples_step);
455477
std::ostringstream text;
456478
{
457-
if (!use_vad || params.interim && params.no_timestamps && s_to_delete.size()) {
479+
if (params.delete_vt100 && s_to_delete.size()) {
458480
printf("\33[2K\r");
459481

460482
// print long empty line to clear the previous line
461483
printf("%s", std::string(s_to_delete.size(), ' ').c_str());
462484

463485
printf("\33[2K\r");
464-
} else if (use_vad && !params.no_timestamps) {
486+
}
487+
s_to_delete.clear();
488+
489+
n_segments = whisper_full_n_segments(ctx);
490+
no_confirmed = (no_confirmed || is_interim && n_segments <= 1);
491+
if (is_running && is_interim && !no_confirmed) {
492+
const int64_t t1_ms = whisper_full_get_segment_t1(ctx, n_segments - 2) * 10;
493+
if (t1_ms < abs(params.step_ms)) {
494+
// too short to confirm
495+
no_confirmed = true;
496+
} else {
497+
t_last += std::chrono::milliseconds(t1_ms);
498+
const auto n_samples_confirmed = (1e-3*t1_ms)*WHISPER_SAMPLE_RATE;
499+
pcmf32.resize(n_samples_confirmed); // for timestamps
500+
n_samples_old -= n_samples_confirmed;
501+
}
502+
}
503+
504+
if (use_vad && !params.no_timestamps && (!is_running || !no_confirmed)) {
465505
const int64_t t1 = (t_last - t_start).count()/1000000;
466506
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
467507

@@ -470,28 +510,42 @@ int main(int argc, char ** argv) {
470510
text << std::endl;
471511
}
472512

473-
n_segments = whisper_full_n_segments(ctx);
474-
if (is_interim) {
475-
if (n_segments < 2) {
476-
is_unconfirmed = true;
477-
} else {
478-
n_segments--;
479-
const int64_t t1_ms = whisper_full_get_segment_t1(ctx, n_segments - 1) * 10;
480-
t_last += std::chrono::milliseconds(t1_ms);
481-
const auto n_confirmed = (1e-3*t1_ms)*WHISPER_SAMPLE_RATE;
482-
pcmf32.resize(n_confirmed);
483-
n_samples_old -= n_confirmed;
484-
}
485-
}
486513
for (int i = 0; i < n_segments; ++i) {
487514
std::string i_text = whisper_full_get_segment_text(ctx, i);
488515

489-
if (!use_vad || params.no_timestamps) {
516+
// last segment may be s_to_delete
517+
if (i == n_segments - 1 && is_running && (no_confirmed || is_interim)) {
518+
if (params.no_timestamps && i > 0) {
519+
text << std::endl;
520+
}
521+
if (is_interim) {
522+
// utf-8 cannot be simply cut into two
523+
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
524+
const auto t_u32 = conv.from_bytes(i_text);
525+
const auto t_sub = conv.to_bytes(t_u32.substr(0, t_u32.size() * 0.7));
526+
i_text = t_sub + "";
527+
}
528+
if (s_to_delete.size() > 0) {
529+
s_to_delete += " ";
530+
}
531+
s_to_delete += i_text;
532+
if (!params.delete_vt100) {
533+
s_to_delete = "(" + s_to_delete + ")";
534+
}
535+
break;
536+
}
537+
538+
if (is_running && no_confirmed) {
539+
if (s_to_delete.size() > 0) {
540+
s_to_delete += " ";
541+
}
542+
s_to_delete += i_text;
543+
} else if (params.no_timestamps) {
490544
if (i > 0) {
491545
text << std::endl;
492546
}
493547
text << i_text;
494-
} else {
548+
} else if (!is_running || !(is_interim && i == n_segments - 1)) {
495549
const int64_t t_end = (t_last - t_start).count()/1000000;
496550
const int64_t t_beg = std::max(0.0, t_end - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
497551
const int64_t t0 = t_beg/10 + whisper_full_get_segment_t0(ctx, i);
@@ -507,10 +561,13 @@ int main(int argc, char ** argv) {
507561
}
508562
}
509563

510-
if (use_vad && !params.no_timestamps) {
564+
if (use_vad && !params.no_timestamps && (!is_running || !no_confirmed)) {
511565
text << std::endl;
512566
text << "### Transcription " << n_iter << " END";
513567
text << std::endl;
568+
if (s_to_delete.size() > 0) {
569+
text << std::endl;
570+
}
514571
}
515572
}
516573

@@ -519,42 +576,33 @@ int main(int argc, char ** argv) {
519576
fout << std::endl;
520577
}
521578

522-
++n_iter;
523-
524-
if (is_unconfirmed) {
525-
--n_iter;
526-
// utf-8 cannot be simply cut into two
527-
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
528-
auto t_u32 = conv.from_bytes(text.str());
529-
auto t_sub = conv.to_bytes(t_u32.substr(0, t_u32.size() / 2));
530-
text.str(t_sub + "");
579+
if (!no_confirmed) {
580+
++n_iter;
531581
}
532582

533583
printf("%s", text.str().c_str());
534584

535-
if (is_unconfirmed || !use_vad && n_samples_old < n_samples_len - n_samples_step) {
536-
s_to_delete = text.str();
585+
if (is_running && (no_confirmed || is_interim)) {
586+
printf("%s%s", s_to_delete.c_str(), params.delete_vt100 ? "" : "\n");
587+
--n_segments; // exclude s_to_delete from context
537588
} else {
538589
printf("\n");
539590
s_to_delete = "";
540591

541-
if (!use_vad) {
542-
n_iter = 0;
543-
if (n_samples_keep < n_samples_old) {
544-
// keep part of the audio for next iteration to try to mitigate word boundary issues
545-
n_samples_old = n_samples_keep;
546-
}
592+
if (is_aborted) {
593+
// keep part of the audio for next iteration to try to mitigate word boundary issues
594+
n_samples_old = std::min(n_samples_old, n_samples_keep);
547595
}
596+
}
548597

549-
// Add tokens of the last full length segment as the prompt
550-
if (!params.no_context) {
551-
prompt_tokens.clear();
598+
// Add tokens of the last full length segment as the prompt
599+
if (!no_confirmed && !params.no_context) {
600+
prompt_tokens.clear();
552601

553-
for (int i = 0; i < n_segments; ++i) {
554-
const int token_count = whisper_full_n_tokens(ctx, i);
555-
for (int j = 0; j < token_count; ++j) {
556-
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
557-
}
602+
for (int i = 0; i < n_segments; ++i) {
603+
const int token_count = whisper_full_n_tokens(ctx, i);
604+
for (int j = 0; j < token_count; ++j) {
605+
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
558606
}
559607
}
560608
}

0 commit comments

Comments
 (0)