Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bindings/ruby/ext/ruby_whisper.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ID id_URI;
ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_parallel;

static bool is_log_callback_finalized = false;

Expand Down Expand Up @@ -142,6 +143,7 @@ void Init_whisper() {
id_pre_converted_models = rb_intern("pre_converted_models");
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_parallel = rb_intern("parallel");

mWhisper = rb_define_module("Whisper");
mVAD = rb_define_module_under(mWhisper, "VAD");
Expand Down
5 changes: 5 additions & 0 deletions bindings/ruby/ext/ruby_whisper_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern ID id_URI;
extern ID id_pre_converted_models;
extern ID id_coreml_compiled_models;
extern ID id_cache;
extern ID id_parallel;

extern VALUE cContext;
extern VALUE eError;
Expand All @@ -24,6 +25,8 @@ extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);

ID transcribe_option_names[1];

static void
ruby_whisper_free(ruby_whisper *rw)
{
Expand Down Expand Up @@ -633,6 +636,8 @@ init_ruby_whisper_context(VALUE *mWhisper)
{
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);

transcribe_option_names[0] = id_parallel;

rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);

Expand Down
19 changes: 15 additions & 4 deletions bindings/ruby/ext/ruby_whisper_transcribe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern const rb_data_type_t ruby_whisper_params_type;

extern ID id_to_s;
extern ID id_call;
extern ID transcribe_option_names[1];

extern void
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
Expand All @@ -34,9 +35,14 @@ VALUE
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
VALUE wave_file_path, blk, params, kws;
VALUE opts[1];

rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);

bool parallel = !(NIL_P(opts[0]) || opts[0] == Qfalse);

rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);

Expand Down Expand Up @@ -66,8 +72,13 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {

prepare_transcription(rwp, &self);

if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
int result;
if (parallel) {
result = whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1);
} else {
result = whisper_full(rw->context, rwp->params, pcmf32.data(), pcmf32.size());
}
if (result != 0) {
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
Expand Down
68 changes: 34 additions & 34 deletions bindings/ruby/sig/whisper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ module Whisper
# transcribe a single file
# can emit to a block results
#
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
#
def transcribe: (string, Params) -> self
| (string, Params) { (String) -> void } -> self
Expand All @@ -50,16 +50,16 @@ module Whisper

# Yields each Whisper::Segment:
#
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
#
# Returns an Enumerator if no block given:
#
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
Expand All @@ -74,25 +74,25 @@ module Whisper

# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer) -> Integer

# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer) -> Integer

# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
#
# full_get_segment_speacker_turn_next(3) # => true
# full_get_segment_speacker_turn_next(3) # => true
#
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)

# Text of a segment indexed by +segment_index+.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer) -> String

Expand Down Expand Up @@ -282,9 +282,9 @@ module Whisper

# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil)
Expand All @@ -297,9 +297,9 @@ module Whisper

# Sets progress callback, called on each progress update.
#
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
#
# +progress+ is an Integer between 0 and 100.
#
Expand Down Expand Up @@ -327,9 +327,9 @@ module Whisper

# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
# # ...
# }
# params.abort_callback = ->(user_data) {
# # ...
# }
#
#
def abort_callback=: (abort_callback) -> abort_callback
Expand Down Expand Up @@ -358,9 +358,9 @@ module Whisper

# Hook called on new segment. Yields each Whisper::Segment.
#
# whisper.on_new_segment do |segment|
# # ...
# end
# whisper.on_new_segment do |segment|
# # ...
# end
#
def on_new_segment: { (Segment) -> void } -> void

Expand All @@ -374,13 +374,13 @@ module Whisper

# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# end
# end
#
def abort_on: { (Object user_data) -> boolish } -> void
end
Expand Down