Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v)
}


// Set language id
func (p *Params) SetLanguage(lang int) error {
if lang == -1 {
Expand Down Expand Up @@ -146,6 +147,10 @@ func (p *Params) SetInitialPrompt(prompt string) {
p.initial_prompt = C.CString(prompt)
}

// Set whether the initial prompt is always prepended for every decode chunk
func (p *Params) SetCarryInitialPrompt(v bool) {
	// Convert the Go bool with the same helper the other boolean setters use.
	carry := toBool(v)
	p.carry_initial_prompt = carry
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

Expand Down Expand Up @@ -199,6 +204,9 @@ func (p *Params) String() string {
if p.token_timestamps {
str += " token_timestamps"
}
if p.carry_initial_prompt {
str += " carry_initial_prompt"
}

return str + ">"
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ public void tdrzEnable(boolean enable) {
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Always prepend initial_prompt for every decode chunk. */
public CBool carry_initial_prompt;

/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
Expand Down Expand Up @@ -336,8 +338,8 @@ protected List<String> getFieldOrder() {
"no_timestamps", "single_segment", "print_special",
"print_progress", "print_realtime", "print_timestamps",
"token_timestamps", "thold_pt", "thold_ptsum", "max_len",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "carry_initial_prompt",
"prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature",
"max_initial_ts", "length_penalty", "temperature_inc",
Expand Down
69 changes: 46 additions & 23 deletions bindings/ruby/ext/ruby_whisper_params.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);

#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 37

extern VALUE cParams;
extern VALUE cVADParams;
Expand All @@ -46,6 +46,7 @@ static ID id_print_special;
static ID id_print_progress;
static ID id_print_realtime;
static ID id_print_timestamps;
static ID id_carry_initial_prompt;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
Expand Down Expand Up @@ -455,6 +456,26 @@ ruby_whisper_params_get_print_timestamps(VALUE self)
{
BOOL_PARAMS_GETTER(self, print_timestamps)
}

/*
 * call-seq:
 *   carry_initial_prompt -> true or false
 *
 * Reader for the carry_initial_prompt parameter.
 *
 * The whole body is produced by the BOOL_PARAMS_GETTER macro, which is
 * defined outside this excerpt; presumably it returns the flag as a Ruby
 * boolean, as the call-seq states -- confirm against the macro definition.
 */
static VALUE
ruby_whisper_params_get_carry_initial_prompt(VALUE self)
{
BOOL_PARAMS_GETTER(self, carry_initial_prompt)
}

/*
 * call-seq:
 *   carry_initial_prompt = bool -> bool
 *
 * Writer for the carry_initial_prompt parameter.
 *
 * The whole body is produced by the BOOL_PARAMS_SETTER macro, which is
 * defined outside this excerpt; presumably it stores the truthiness of
 * +value+ and returns it, as the call-seq states -- confirm against the
 * macro definition.
 */
static VALUE
ruby_whisper_params_set_carry_initial_prompt(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, carry_initial_prompt, value)
}
/*
* call-seq:
* suppress_blank = force_suppress -> force_suppress
Expand Down Expand Up @@ -1168,6 +1189,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(max_len)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
SET_PARAM_IF_SAME(carry_initial_prompt)
SET_PARAM_IF_SAME(offset)
SET_PARAM_IF_SAME(duration)
SET_PARAM_IF_SAME(max_text_tokens)
Expand Down Expand Up @@ -1303,28 +1325,29 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(max_len, 11)
DEFINE_PARAM(split_on_word, 12)
DEFINE_PARAM(initial_prompt, 13)
DEFINE_PARAM(diarize, 14)
DEFINE_PARAM(offset, 15)
DEFINE_PARAM(duration, 16)
DEFINE_PARAM(max_text_tokens, 17)
DEFINE_PARAM(temperature, 18)
DEFINE_PARAM(max_initial_ts, 19)
DEFINE_PARAM(length_penalty, 20)
DEFINE_PARAM(temperature_inc, 21)
DEFINE_PARAM(entropy_thold, 22)
DEFINE_PARAM(logprob_thold, 23)
DEFINE_PARAM(no_speech_thold, 24)
DEFINE_PARAM(new_segment_callback, 25)
DEFINE_PARAM(new_segment_callback_user_data, 26)
DEFINE_PARAM(progress_callback, 27)
DEFINE_PARAM(progress_callback_user_data, 28)
DEFINE_PARAM(encoder_begin_callback, 29)
DEFINE_PARAM(encoder_begin_callback_user_data, 30)
DEFINE_PARAM(abort_callback, 31)
DEFINE_PARAM(abort_callback_user_data, 32)
DEFINE_PARAM(vad, 33)
DEFINE_PARAM(vad_model_path, 34)
DEFINE_PARAM(vad_params, 35)
DEFINE_PARAM(carry_initial_prompt, 14)
DEFINE_PARAM(diarize, 15)
DEFINE_PARAM(offset, 16)
DEFINE_PARAM(duration, 17)
DEFINE_PARAM(max_text_tokens, 18)
DEFINE_PARAM(temperature, 19)
DEFINE_PARAM(max_initial_ts, 20)
DEFINE_PARAM(length_penalty, 21)
DEFINE_PARAM(temperature_inc, 22)
DEFINE_PARAM(entropy_thold, 23)
DEFINE_PARAM(logprob_thold, 24)
DEFINE_PARAM(no_speech_thold, 25)
DEFINE_PARAM(new_segment_callback, 26)
DEFINE_PARAM(new_segment_callback_user_data, 27)
DEFINE_PARAM(progress_callback, 28)
DEFINE_PARAM(progress_callback_user_data, 29)
DEFINE_PARAM(encoder_begin_callback, 30)
DEFINE_PARAM(encoder_begin_callback_user_data, 31)
DEFINE_PARAM(abort_callback, 32)
DEFINE_PARAM(abort_callback_user_data, 33)
DEFINE_PARAM(vad, 34)
DEFINE_PARAM(vad_model_path, 35)
DEFINE_PARAM(vad_params, 36)

rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
Expand Down
3 changes: 3 additions & 0 deletions bindings/ruby/sig/whisper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ module Whisper
?max_len: Integer,
?split_on_word: boolish,
?initial_prompt: string | nil,
?carry_initial_prompt: boolish,
?diarize: boolish,
?offset: Integer,
?duration: Integer,
Expand Down Expand Up @@ -236,13 +237,15 @@ module Whisper
def split_on_word: () -> (true | false)

def initial_prompt=: (_ToS) -> _ToS
# Enables always prepending initial_prompt for every decode chunk.
def carry_initial_prompt=: (boolish) -> boolish

# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
# use whisper_tokenize() to convert text to tokens.
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil)
# Whether initial_prompt is always prepended for every decode chunk.
def carry_initial_prompt: () -> (true | false)

def diarize=: (boolish) -> boolish

Expand Down
8 changes: 8 additions & 0 deletions bindings/ruby/test/test_params.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class TestParams < TestBase
:max_len,
:split_on_word,
:initial_prompt,
:carry_initial_prompt,
:diarize,
:offset,
:duration,
Expand Down Expand Up @@ -119,6 +120,13 @@ def test_print_timestamps
assert !@params.print_timestamps
end

# Checks that carry_initial_prompt can be switched on and off through its
# writer and that the reader reflects each assignment.
def test_carry_initial_prompt
  @params.carry_initial_prompt = true
  assert @params.carry_initial_prompt
  @params.carry_initial_prompt = false
  # The negation had been garbled into "[email protected]" by e-mail obfuscation.
  assert !@params.carry_initial_prompt
end

def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
Expand Down
Loading