Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bindings/ruby/ext/ruby_whisper.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ID id_URI;
ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_parallel;

static bool is_log_callback_finalized = false;

Expand Down Expand Up @@ -142,6 +143,7 @@ void Init_whisper() {
id_pre_converted_models = rb_intern("pre_converted_models");
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_parallel = rb_intern("parallel");

mWhisper = rb_define_module("Whisper");
mVAD = rb_define_module_under(mWhisper, "VAD");
Expand Down
5 changes: 5 additions & 0 deletions bindings/ruby/ext/ruby_whisper_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern ID id_URI;
extern ID id_pre_converted_models;
extern ID id_coreml_compiled_models;
extern ID id_cache;
extern ID id_parallel;

extern VALUE cContext;
extern VALUE eError;
Expand All @@ -24,6 +25,8 @@ extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);

ID transcribe_option_names[1];

static void
ruby_whisper_free(ruby_whisper *rw)
{
Expand Down Expand Up @@ -633,6 +636,8 @@ init_ruby_whisper_context(VALUE *mWhisper)
{
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);

transcribe_option_names[0] = id_parallel;

rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);

Expand Down
19 changes: 15 additions & 4 deletions bindings/ruby/ext/ruby_whisper_transcribe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern const rb_data_type_t ruby_whisper_params_type;

extern ID id_to_s;
extern ID id_call;
extern ID transcribe_option_names[1];

extern void
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
Expand All @@ -34,9 +35,14 @@ VALUE
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
VALUE wave_file_path, blk, params, kws;
VALUE opts[1];

rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);

bool parallel = !(NIL_P(opts[0]) || opts[0] == Qfalse);

rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);

Expand Down Expand Up @@ -66,8 +72,13 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {

prepare_transcription(rwp, &self);

if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
int result;
if (parallel) {
result = whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1);
} else {
result = whisper_full(rw->context, rwp->params, pcmf32.data(), pcmf32.size());
}
if (result != 0) {
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
Expand Down