Skip to content

Commit 4c60e6b

Browse files
committed
Wrong implementation of carry_initial_prompt
1 parent 53c9a3a commit 4c60e6b

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

include/whisper.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -482,13 +482,14 @@ extern "C" {
482482
int duration_ms; // audio duration to process in ms
483483

484484
bool translate;
485-
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
486-
bool no_timestamps; // do not generate timestamps
487-
bool single_segment; // force single segment output (useful for streaming)
488-
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
489-
bool print_progress; // print progress information
490-
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
491-
bool print_timestamps; // print timestamps for each text segment when printing realtime
485+
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
486+
bool carry_initial_prompt; // carry the initial prompt to the next call
487+
bool no_timestamps; // do not generate timestamps
488+
bool single_segment; // force single segment output (useful for streaming)
489+
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
490+
bool print_progress; // print progress information
491+
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
492+
bool print_timestamps; // print timestamps for each text segment when printing realtime
492493

493494
// [EXPERIMENTAL] token-level timestamps
494495
bool token_timestamps; // enable token-level timestamps

src/whisper.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4646,14 +4646,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
46464646
/*.offset_ms =*/ 0,
46474647
/*.duration_ms =*/ 0,
46484648

4649-
/*.translate =*/ false,
4650-
/*.no_context =*/ true,
4651-
/*.no_timestamps =*/ false,
4652-
/*.single_segment =*/ false,
4653-
/*.print_special =*/ false,
4654-
/*.print_progress =*/ true,
4655-
/*.print_realtime =*/ false,
4656-
/*.print_timestamps =*/ true,
4649+
/*.translate =*/ false,
4650+
/*.no_context =*/ true,
4651+
/*.carry_initial_prompt =*/ false,
4652+
/*.no_timestamps =*/ false,
4653+
/*.single_segment =*/ false,
4654+
/*.print_special =*/ false,
4655+
/*.print_progress =*/ true,
4656+
/*.print_realtime =*/ false,
4657+
/*.print_timestamps =*/ true,
46574658

46584659
/*.token_timestamps =*/ false,
46594660
/*.thold_pt =*/ 0.01f,
@@ -5454,6 +5455,7 @@ int whisper_full_with_state(
54545455
prompt_tokens.resize(-n_needed);
54555456
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
54565457
}
5458+
remaining_prompt_length = ctx->model.hparams.n_text_ctx / 2 - 1 - initial_prompt_tokens.size();
54575459
prompt_tokens.resize(n_needed);
54585460
params.prompt_tokens = prompt_tokens.data();
54595461
params.prompt_n_tokens = prompt_tokens.size();
@@ -5610,9 +5612,21 @@ int whisper_full_with_state(
56105612
// init prompt and kv cache for the current iteration
56115613
// TODO: do not recompute the prompt if it is the same as previous time
56125614
{
5613-
prompt.clear();
5615+
// LLMs think we should add this if block here
5616+
if (params.carry_initial_prompt) {
5617+
// Prepend initial_prompt_tokens to the prompt
5618+
int nignored = std::max((int)initial_prompt_tokens.size(), prompt_past.size());
5619+
std::vector<whisper_token> remaining_prompt(prompt_past.begin() + nignored, prompt_past.end());
5620+
remaining_prompt.resize(std::min(remaining_prompt.size(), remaining_prompt_length));
5621+
prompt.clear();
5622+
prompt.insert(prompt.end(), initial_prompt_tokens.begin(), initial_prompt_tokens.end());
5623+
prompt.insert(prompt.end(), remaining_prompt.begin(), remaining_prompt.end());
5624+
} else {
5625+
prompt.clear();
5626+
}
56145627

56155628
// if we have already generated some text, use it as a prompt to condition the next generation
5629+
// But maybe we can put it here?
56165630
if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
56175631
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
56185632

0 commit comments

Comments
 (0)