diff --git a/bindings/go/params.go b/bindings/go/params.go index 95c5bfaf934..d8dee57e331 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) { p.print_timestamps = toBool(v) } + // Set language id func (p *Params) SetLanguage(lang int) error { if lang == -1 { @@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) { p.initial_prompt = C.CString(prompt) } +func (p *Params) SetCarryInitialPrompt(v bool) { + p.carry_initial_prompt = toBool(v) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS @@ -199,6 +204,9 @@ func (p *Params) String() string { if p.token_timestamps { str += " token_timestamps" } + if p.carry_initial_prompt { + str += " carry_initial_prompt" + } return str + ">" } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 498ff126037..76ce80fb4cc 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -157,6 +157,8 @@ public void tdrzEnable(boolean enable) { /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. */ public String initial_prompt; + /** Always prepend initial_prompt for every decode chunk. */ + public CBool carry_initial_prompt; /** Prompt tokens. (int*) */ public Pointer prompt_tokens; @@ -336,8 +338,8 @@ protected List getFieldOrder() { "no_timestamps", "single_segment", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", - "split_on_word", "max_tokens", "debug_mode", "audio_ctx", - "tdrz_enable", "suppress_regex", "initial_prompt", + "split_on_word", "max_tokens", "debug_mode", "audio_ctx", + "tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 882c68d042f..70417cb1664 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -26,7 +26,7 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37 extern VALUE cParams; extern VALUE cVADParams; @@ -46,6 +46,7 @@ static ID id_print_special; static ID id_print_progress; static ID id_print_realtime; static ID id_print_timestamps; +static ID id_carry_initial_prompt; static ID id_suppress_blank; static ID id_suppress_nst; static ID id_token_timestamps; @@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, print_timestamps) } + +/* + * call-seq: + * carry_initial_prompt -> true or false + */ +static VALUE +ruby_whisper_params_get_carry_initial_prompt(VALUE self) +{ + BOOL_PARAMS_GETTER(self, carry_initial_prompt) +} + +/* + * call-seq: + * carry_initial_prompt = bool -> bool + */ +static VALUE +ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, carry_initial_prompt, value) +} /* * call-seq: * suppress_blank = force_suppress -> force_suppress @@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(max_len) SET_PARAM_IF_SAME(split_on_word) SET_PARAM_IF_SAME(initial_prompt) + SET_PARAM_IF_SAME(carry_initial_prompt) SET_PARAM_IF_SAME(offset) SET_PARAM_IF_SAME(duration) SET_PARAM_IF_SAME(max_text_tokens) @@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(max_len, 11) DEFINE_PARAM(split_on_word, 12) DEFINE_PARAM(initial_prompt, 13) - DEFINE_PARAM(diarize, 14) - DEFINE_PARAM(offset, 15) - DEFINE_PARAM(duration, 16) - DEFINE_PARAM(max_text_tokens, 17) - DEFINE_PARAM(temperature, 18) - DEFINE_PARAM(max_initial_ts, 19) - DEFINE_PARAM(length_penalty, 20) - DEFINE_PARAM(temperature_inc, 21) - DEFINE_PARAM(entropy_thold, 22) - DEFINE_PARAM(logprob_thold, 23) - DEFINE_PARAM(no_speech_thold, 24) - DEFINE_PARAM(new_segment_callback, 25) - DEFINE_PARAM(new_segment_callback_user_data, 26) - DEFINE_PARAM(progress_callback, 27) - DEFINE_PARAM(progress_callback_user_data, 28) - DEFINE_PARAM(encoder_begin_callback, 29) - DEFINE_PARAM(encoder_begin_callback_user_data, 30) - DEFINE_PARAM(abort_callback, 31) - DEFINE_PARAM(abort_callback_user_data, 32) - DEFINE_PARAM(vad, 33) - DEFINE_PARAM(vad_model_path, 34) - DEFINE_PARAM(vad_params, 35) + DEFINE_PARAM(carry_initial_prompt, 14) + DEFINE_PARAM(diarize, 15) + DEFINE_PARAM(offset, 16) + DEFINE_PARAM(duration, 17) + DEFINE_PARAM(max_text_tokens, 18) + DEFINE_PARAM(temperature, 19) + DEFINE_PARAM(max_initial_ts, 20) + DEFINE_PARAM(length_penalty, 21) + DEFINE_PARAM(temperature_inc, 22) + DEFINE_PARAM(entropy_thold, 23) + DEFINE_PARAM(logprob_thold, 24) + DEFINE_PARAM(no_speech_thold, 25) + DEFINE_PARAM(new_segment_callback, 26) + DEFINE_PARAM(new_segment_callback_user_data, 27) + DEFINE_PARAM(progress_callback, 28) + DEFINE_PARAM(progress_callback_user_data, 29) + DEFINE_PARAM(encoder_begin_callback, 30) + DEFINE_PARAM(encoder_begin_callback_user_data, 31) + DEFINE_PARAM(abort_callback, 32) + DEFINE_PARAM(abort_callback_user_data, 33) + DEFINE_PARAM(vad, 34) + DEFINE_PARAM(vad_model_path, 35) + DEFINE_PARAM(vad_params, 36) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 0489432a249..d5905dd7037 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -138,6 +138,7 @@ module Whisper ?max_len: Integer, ?split_on_word: boolish, ?initial_prompt: string | nil, + ?carry_initial_prompt: boolish, ?diarize: boolish, ?offset: Integer, ?duration: Integer, @@ -236,6 +237,7 @@ module Whisper def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS + def carry_initial_prompt=: (boolish) -> boolish # Tokens to provide to the whisper decoder as initial prompt # these are prepended to any existing text context from a previous call @@ -243,6 +245,7 @@ module Whisper # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). # def initial_prompt: () -> (String | nil) + def carry_initial_prompt: () -> (true | false) def diarize=: (boolish) -> boolish diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index d5c5d140e8c..4dd9780de7d 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -16,6 +16,7 @@ class TestParams < TestBase :max_len, :split_on_word, :initial_prompt, + :carry_initial_prompt, :diarize, :offset, :duration, @@ -119,6 +120,13 @@ def test_print_timestamps assert !@params.print_timestamps end + def test_carry_initial_prompt + @params.carry_initial_prompt = true + assert @params.carry_initial_prompt + @params.carry_initial_prompt = false + assert !@params.carry_initial_prompt + end + def test_suppress_blank @params.suppress_blank = true assert @params.suppress_blank diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 457a1ff35c2..f5c0f86cc3e 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -5,6 +5,7 @@ #include "grammar-parser.h" #include +#include #include #include #include @@ -77,6 +78,7 @@ struct whisper_params { bool use_gpu = true; bool flash_attn = true; bool suppress_nst = false; + bool carry_initial_prompt = false; std::string language = "en"; std::string prompt; @@ -145,60 +147,61 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params exit(0); } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); } - else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); } - else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } - else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } - else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } - else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } - else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; } - else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } - else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } - else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; } - else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); } - else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if ( arg == "--print-confidence"){ params.print_confidence= true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); } - else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } - else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } - else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } - else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } - else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; } - else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } - else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; } - else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-ojf" || arg == "--output-json-full") { params.output_jsn_full = params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if ( arg == "--print-confidence") { params.print_confidence= true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; } + else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } + else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; } + else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } + else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); } // Voice Activity Detection (VAD) else if ( arg == "--vad") { params.vad = true; } else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; } @@ -224,61 +227,62 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); - fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); - fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); - fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); - fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", ""); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); - fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); - fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); - fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); - fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); - fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); - fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); - fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); - fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); + fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); + fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); + fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input audio file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false"); @@ -387,7 +391,11 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct const char * text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); + const int n_colors = (int) k_colors.size(); + int raw_col = (int) (std::pow(p, 3)*float(n_colors)); + if (raw_col < 0) raw_col = 0; + if (raw_col > n_colors - 1) raw_col = n_colors - 1; + const int col = raw_col; printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } @@ -1178,7 +1186,8 @@ int main(int argc, char ** argv) { wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); - wparams.initial_prompt = params.prompt.c_str(); + wparams.initial_prompt = params.prompt.c_str(); + wparams.carry_initial_prompt = params.carry_initial_prompt; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; diff --git a/include/whisper.h b/include/whisper.h index fcd756a9fe2..f4cc6bf7abd 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -525,6 +525,7 @@ extern "C" { // use whisper_tokenize() to convert text to tokens // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224) const char * initial_prompt; + bool carry_initial_prompt; // if true, always prepend initial_prompt to every decode window (may reduce conditioning on previous text) const whisper_token * prompt_tokens; int prompt_n_tokens; diff --git a/src/whisper.cpp b/src/whisper.cpp index 39c53ba233a..c83e995be7b 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -138,6 +138,10 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text } while (0) #define WHISPER_MAX_DECODERS 8 + +// temperature below which we condition on past text history +static constexpr float WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF = 0.5f; + #define WHISPER_MAX_NODES 4096 static std::string format(const char * fmt, ...) { @@ -880,7 +884,10 @@ struct whisper_state { std::vector logits; std::vector result_all; - std::vector prompt_past; + + // prompt history split into static prefix (prompt_past0) and dynamic rolling context (prompt_past1) + std::vector prompt_past0; // static carried initial prompt (if enabled) + std::vector prompt_past1; // dynamic context from decoded output int lang_id = 0; // english by default @@ -5920,9 +5927,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /* suppress_regex =*/ nullptr, - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.initial_prompt =*/ nullptr, + /*.carry_initial_prompt =*/ false, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, /*.language =*/ "en", /*.detect_language =*/ false, @@ -6874,17 +6882,19 @@ int whisper_full_with_state( decoder.rng = std::mt19937(j); } - // the accumulated text context so far - auto & prompt_past = state->prompt_past; + // the accumulated text context split into static (prompt_past0) and dynamic (prompt_past1) + auto & prompt_past0 = state->prompt_past0; + auto & prompt_past1 = state->prompt_past1; if (params.no_context) { - prompt_past.clear(); + prompt_past0.clear(); + prompt_past1.clear(); } // prepare prompt { std::vector prompt_tokens; - // initial prompt + // tokenize the initial prompt if (!params.prompt_tokens && params.initial_prompt) { prompt_tokens.resize(1024); int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); @@ -6896,14 +6906,17 @@ int whisper_full_with_state( params.prompt_tokens = prompt_tokens.data(); params.prompt_n_tokens = prompt_tokens.size(); } - - // prepend the prompt tokens to the prompt_past if (params.prompt_tokens && params.prompt_n_tokens > 0) { - // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { - prompt_past.push_back(params.prompt_tokens[i]); + if (params.carry_initial_prompt) { + if (prompt_past0.empty()) { + prompt_past0.assign(params.prompt_tokens, params.prompt_tokens + params.prompt_n_tokens); + } + } else { + for (int i = 0; i < params.prompt_n_tokens; ++i) { + prompt_past1.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past1.begin(), prompt_past1.end() - params.prompt_n_tokens, prompt_past1.end()); } - std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); } } @@ -6989,7 +7002,7 @@ int whisper_full_with_state( // if there is a very short audio segment left to process, we remove any past prompt since it tends // to confuse the decoder and often make it repeat or hallucinate stuff if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); + prompt_past1.clear(); } int best_decoder_id = 0; @@ -7050,12 +7063,44 @@ int whisper_full_with_state( { prompt.clear(); - // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + if (params.n_max_text_ctx > 0 && + t_cur < WHISPER_HISTORY_CONDITIONING_TEMP_CUTOFF) { + + const bool have_dynamic = !prompt_past1.empty(); + const bool can_carry_static = params.carry_initial_prompt && !prompt_past0.empty() && seek != seek_start; + + int max_ctx_half = std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2); + if (max_ctx_half > 0 && (have_dynamic || can_carry_static)) { + // Always start with previous token marker to connect continuity + prompt.push_back(whisper_token_prev(ctx)); + + if (can_carry_static) { + // Budget includes the prev token; we already consumed 1 slot. + int budget = max_ctx_half; // total allowed (including prev) - prompt = { whisper_token_prev(ctx) }; - prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + // Take as many static tokens as fit (reserving at least the prev token already placed) + int take_static = std::min(budget - 1, (int) prompt_past0.size()); + if (take_static > 0) { + auto start0 = take_static < (int) prompt_past0.size() ? prompt_past0.end() - take_static : prompt_past0.begin(); + prompt.insert(prompt.end(), start0, prompt_past0.end()); + } + + // Remaining budget for dynamic tail + int remaining = budget - take_static; + if (remaining > 0) { + int take_dynamic = std::min(remaining, (int) prompt_past1.size()); + if (take_dynamic > 0) { + prompt.insert(prompt.end(), prompt_past1.end() - take_dynamic, prompt_past1.end()); + } + } + } else { + // Dynamic only path + int n_take = std::min(max_ctx_half, (int) prompt_past1.size()); + if (n_take > 0) { + prompt.insert(prompt.end(), prompt_past1.end() - n_take, prompt_past1.end()); + } + } + } } // init new transcription with sot, language (opt) and task tokens @@ -7537,14 +7582,16 @@ int whisper_full_with_state( //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); - // update prompt_past - prompt_past.clear(); - if (prompt.front() == whisper_token_prev(ctx)) { - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + // update prompt_past1 + prompt_past1.clear(); + if (!prompt.empty() && prompt.front() == whisper_token_prev(ctx)) { + prompt_past1.insert(prompt_past1.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); } - for (int i = 0; i < result_len && !is_no_speech; ++i) { - prompt_past.push_back(tokens_cur[i].id); + if (!is_no_speech) { + for (int i = 0; i < result_len; ++i) { + prompt_past1.push_back(tokens_cur[i].id); + } } if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {