Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion bindings/ruby/ext/options.rb
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def configure
bool "GGML_HIP_GRAPHS"
bool "GGML_HIP_NO_VMM"
bool "GGML_HIP_ROCWMMA_FATTN"
bool "GGML_HIP_UMA"
ignored "GGML_INCLUDE_INSTALL_DIR"
bool "GGML_KOMPUTE"
bool "GGML_LASX"
Expand Down
1 change: 1 addition & 0 deletions bindings/ruby/ext/ruby_whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typedef struct {
bool diarize;
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_params;

Expand Down
119 changes: 116 additions & 3 deletions bindings/ruby/ext/ruby_whisper_params.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);

#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32

extern VALUE cParams;

Expand Down Expand Up @@ -63,6 +63,8 @@ static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data;
static ID id_progress_callback;
static ID id_progress_callback_user_data;
static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback;
static ID id_abort_callback_user_data;

Expand Down Expand Up @@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
}
}

/*
 * Native encoder-begin hook handed to whisper via params.encoder_begin_callback.
 *
 * user_data must point at a ruby_whisper_callback_container. The container's
 * single callback (when not nil) is called with (context, nil, user_data);
 * afterwards every proc stored in container->callbacks is called with no
 * arguments. All registered callables are always invoked, even after one of
 * them has already requested an abort.
 *
 * Returns false when at least one callable returned exactly Ruby +false+
 * (nil and any other value do not count), true otherwise.
 */
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
  const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
  bool is_aborted = false;

  // The whisper state is not exposed to Ruby yet (nil is passed in its place)
  // because supporting it requires resolving GC-related problems first.
  if (!NIL_P(container->callback)) {
    const VALUE result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
    if (result == Qfalse) {
      is_aborted = true;
    }
  }

  // RARRAY_LEN yields a long, so the index is a long as well to avoid narrowing.
  const long callbacks_len = RARRAY_LEN(container->callbacks);
  for (long j = 0; j < callbacks_len; j++) {
    const VALUE cb = rb_ary_entry(container->callbacks, j);
    if (rb_funcall(cb, id_call, 0) == Qfalse) {
      is_aborted = true;
    }
  }
  return !is_aborted;
}

static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
Expand Down Expand Up @@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}

if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
rwp->encoder_begin_callback_container->context = context;
rwp->params.encoder_begin_callback = encoder_begin_callback;
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}

if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
Expand All @@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
{
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
}

Expand All @@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
Expand Down Expand Up @@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
rwp->progress_callback_container->user_data = value;
return value;
}

/*
 * Returns the value currently stored as the encoder begin callback.
 */
static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
  ruby_whisper_params *params_data;
  Data_Get_Struct(self, ruby_whisper_params, params_data);

  const ruby_whisper_callback_container *container = params_data->encoder_begin_callback_container;
  return container->callback;
}

/*
 * Sets encoder begin callback, called when the encoder starts.
 * The callback is invoked with the context, nil (in place of the state)
 * and the user data.
 *
 *   params.encoder_begin_callback = ->(context, _, user_data) {
 *     # ...
 *   }
 *
 * call-seq:
 *   encoder_begin_callback = callback -> callback
 */
static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
  ruby_whisper_params *params_data;
  Data_Get_Struct(self, ruby_whisper_params, params_data);

  ruby_whisper_callback_container *container = params_data->encoder_begin_callback_container;
  container->callback = value;
  return value;
}

/*
 * Returns the user data object stored for the encoder begin callback.
 */
static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
  ruby_whisper_params *params_data;
  Data_Get_Struct(self, ruby_whisper_params, params_data);

  const ruby_whisper_callback_container *container = params_data->encoder_begin_callback_container;
  return container->user_data;
}

/*
 * Sets the user data that is passed as the last argument of the
 * encoder begin callback.
 *
 * call-seq:
 *   encoder_begin_callback_user_data = user_data -> user_data
 */
static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
  ruby_whisper_params *params_data;
  Data_Get_Struct(self, ruby_whisper_params, params_data);

  ruby_whisper_callback_container *container = params_data->encoder_begin_callback_container;
  container->user_data = value;
  return value;
}

static VALUE
ruby_whisper_params_get_abort_callback(VALUE self)
{
Expand Down Expand Up @@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data)
SET_PARAM_IF_SAME(encoder_begin_callback)
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data)
}
Expand Down Expand Up @@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
return Qnil;
}

/*
 * Hook called when the encoder starts. The block is called without
 * arguments. Each call registers one more block; previously registered
 * blocks are kept. A block that returns +false+ makes the native
 * encoder-begin callback report +false+ to whisper.
 *
 *   params.on_encoder_begin do
 *     # ...
 *   end
 *
 * call-seq:
 *   on_encoder_begin { ... }
 */
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
  ruby_whisper_params *rwp;
  Data_Get_Struct(self, ruby_whisper_params, rwp);
  // Capture the given block as a Proc and append it to the hook list.
  const VALUE blk = rb_block_proc();
  rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk);
  return Qnil;
}

/*
* Call block to determine whether abort or not. Return +true+ when you want to abort.
*
Expand Down Expand Up @@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(abort_callback, 28)
DEFINE_PARAM(abort_callback_user_data, 29)
DEFINE_PARAM(encoder_begin_callback, 28)
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
DEFINE_PARAM(abort_callback, 30)
DEFINE_PARAM(abort_callback_user_data, 31)

rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
}
17 changes: 9 additions & 8 deletions bindings/ruby/ext/ruby_whisper_transcribe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// Commented out because it is work in progress
// {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race

rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data;
// return !is_aborted;
// };
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }

register_callbacks(rwp, &self);

Expand Down
4 changes: 2 additions & 2 deletions bindings/ruby/lib/whisper/model/uri.rb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def request(uri, headers)
http.request request do |response|
case response
when Net::HTTPNotModified
# noop
# noop
when Net::HTTPOK
download response
when Net::HTTPRedirection
Expand All @@ -68,7 +68,7 @@ def request(uri, headers)
rescue => err
if cache_path.exist?
warn err
# Use cache file
# Use cache file
else
raise
end
Expand Down
19 changes: 19 additions & 0 deletions bindings/ruby/sig/whisper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Whisper
type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish

LOG_LEVEL_NONE: Integer
Expand Down Expand Up @@ -146,6 +147,8 @@ module Whisper
?new_segment_callback_user_data: Object,
?progress_callback: progress_callback,
?progress_callback_user_data: Object,
?encoder_begin_callback: encoder_begin_callback,
?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback,
?abort_callback_user_data: Object
) -> instance
Expand Down Expand Up @@ -306,6 +309,18 @@ module Whisper

def progress_callback_user_data: () -> Object

# Sets encoder begin callback, called when the encoder starts.
#
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback

def encoder_begin_callback: () -> (encoder_begin_callback | nil)

# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object) -> Object

def encoder_begin_callback_user_data: () -> Object

# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
Expand Down Expand Up @@ -335,6 +350,10 @@ module Whisper
#
def on_progress: { (Integer progress) -> void } -> void

# Hook called when the encoder starts.
#
def on_encoder_begin: { () -> void } -> void

# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
Expand Down
4 changes: 2 additions & 2 deletions bindings/ruby/tests/helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")

class << self
attr_reader :whisper
def whisper
return @whisper if @whisper

def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
Expand Down
42 changes: 42 additions & 0 deletions bindings/ruby/tests/test_callback.rb
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,48 @@ def test_on_progress
assert_equal 100, last
end

# The encoder begin callback must be invoked at least once per transcription.
def test_encoder_begin_callback
  call_count = 0
  @params.encoder_begin_callback = lambda do |_context, _state, _user_data|
    # Returns the new count (never false), so transcription is not aborted.
    call_count += 1
  end
  @whisper.transcribe(@audio, @params)
  assert call_count > 0
end

# Returning false from the encoder begin callback must abort transcription,
# which is observed through the error message whisper logs.
def test_encoder_begin_callback_abort
  logs = []
  Whisper.log_set ->(level, buffer, user_data) {
    logs << buffer if level == Whisper::LOG_LEVEL_ERROR
  }, logs
  @params.encoder_begin_callback = ->(context, state, user_data) {
    false
  }
  @whisper.transcribe(@audio, @params)
  assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
ensure
  # Always restore a no-op logger, even when the assertion above fails,
  # so the capturing hook does not leak into other tests.
  Whisper.log_set ->(level, buffer, user_data) {}, nil
end

# The object set as user data must be handed to the callback unchanged.
def test_encoder_begin_callback_user_data
  expected = Object.new
  received = nil
  @params.encoder_begin_callback_user_data = expected
  @params.encoder_begin_callback = lambda do |_context, _state, user_data|
    received = user_data
  end
  @whisper.transcribe(@audio, @params)
  assert_same expected, received
end

# A block registered through on_encoder_begin must run at least once.
def test_on_encoder_begin
  hits = 0
  @params.on_encoder_begin { hits += 1 }
  @whisper.transcribe(@audio, @params)
  assert hits > 0
end

def test_abort_callback
i = 0
@params.abort_callback = ->(user_data) {
Expand Down
2 changes: 1 addition & 1 deletion bindings/ruby/whispercpp.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.2'
s.date = '2025-04-17'
s.date = '2025-04-25'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = '[email protected]'
s.extra_rdoc_files = ['LICENSE', 'README.md']
Expand Down
Loading