Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
647c5f6
Add Params#new_segment_callback= method
KitaitiMakoto Oct 17, 2024
81e6df3
Add tests for Params#new_segment_callback=
KitaitiMakoto Oct 17, 2024
71b65b0
Group tests for #transcribe
KitaitiMakoto Oct 17, 2024
e8243b7
Don't use static for thread-safety
KitaitiMakoto Oct 17, 2024
d0d55f5
Set new_segment_callback only when necessary
KitaitiMakoto Oct 17, 2024
6077d90
Remove redundant check
KitaitiMakoto Oct 17, 2024
8fdbb20
[skip ci] Add Ruby version README
KitaitiMakoto Oct 18, 2024
abbee84
Revert "Group tests for #transcribe"
KitaitiMakoto Oct 19, 2024
67b375a
Revert "Add tests for Params#new_segment_callback="
KitaitiMakoto Oct 19, 2024
050a116
Add test for Context#full_n_segments
KitaitiMakoto Oct 19, 2024
b152263
Add Context#full_n_segments
KitaitiMakoto Oct 19, 2024
59db172
Add tests for lang API
KitaitiMakoto Oct 19, 2024
207a3f1
Add lang API
KitaitiMakoto Oct 19, 2024
22035fb
Add tests for Context#full_lang_id API
KitaitiMakoto Oct 19, 2024
8799616
Add Context#full_lang_id
KitaitiMakoto Oct 19, 2024
eef03e4
Add abnormal test cases for lang
KitaitiMakoto Oct 19, 2024
3f2f232
Raise appropriate errors from lang APIs
KitaitiMakoto Oct 19, 2024
7144916
Add tests for Context#full_get_segment_t{0,1} API
KitaitiMakoto Oct 19, 2024
5ebd2b5
Add Context#full_get_segment_t{0,1}
KitaitiMakoto Oct 19, 2024
3951acc
Add tests for Context#full_get_segment_speaker_turn_next API
KitaitiMakoto Oct 19, 2024
9e04f7a
Add Context#full_get_segment_speaker_turn_next
KitaitiMakoto Oct 19, 2024
0672e6f
Add tests for Context#full_get_segment_text
KitaitiMakoto Oct 19, 2024
5e350b1
Add Context#full_get_setgment_text
KitaitiMakoto Oct 19, 2024
d3a5157
Add tests for Params#new_segment_callback=
KitaitiMakoto Oct 20, 2024
a71d12e
Run new segment callback
KitaitiMakoto Oct 20, 2024
a6028f3
Split tests to multiple files
KitaitiMakoto Oct 20, 2024
c20afc3
Use container struct for new segment callback
KitaitiMakoto Oct 20, 2024
1e72d62
Add tests for Params#new_segment_callback_user_data=
KitaitiMakoto Oct 20, 2024
0a9957d
Add Whisper::Params#new_user_callback_user_data=
KitaitiMakoto Oct 20, 2024
73934b5
Add GC-related test for new segment callback
KitaitiMakoto Oct 20, 2024
1a3ff7c
Protect new segment callback related structs from GC
KitaitiMakoto Oct 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion bindings/ruby/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
README.md
LICENSE
pkg/
lib/whisper.*
63 changes: 63 additions & 0 deletions bindings/ruby/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
whispercpp
==========

![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)

Ruby bindings for [whisper.cpp][], an interface of automatic speech recognition model.

Installation
------------

Install the gem and add to the application's Gemfile by executing:

$ bundle add whispercpp

If bundler is not being used to manage dependencies, install the gem by executing:

$ gem install whispercpp

Usage
-----

NOTE: This gem is still in development. API is not stable for now.

```ruby
require "whisper"

whisper = Whisper::Context.new("path/to/model.bin")

params = Whisper::Params.new
params.language = "en"
params.offset = 10_000
params.duration = 60_000
params.max_text_tokens = 300
params.translate = true
params.print_timestamps = false
params.new_segment_callback = ->(output, t0, t1, index) {
puts "segment #{index}: #{t0}ms -> #{t1}ms: #{output}"
}

whisper.transcribe("path/to/audio.wav", params) do |whole_text|
puts whole_text
end

```

### Preparing model ###

Use script to download model file(s):

```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
sh ./models/download-ggml-model.sh base.en
```

There are some types of models. See [models][] page for details.

### Preparing audio file ###

Currently, whisper.cpp accepts only 16-bit WAV files.

[whisper.cpp]: https://github.com/ggerganov/whisper.cpp
[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models
131 changes: 131 additions & 0 deletions bindings/ruby/ext/ruby_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,37 @@ VALUE mWhisper;
VALUE cContext;
VALUE cParams;

static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
return INT2NUM(whisper_lang_max_id());
}

static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
const char * lang_str = StringValueCStr(lang);
const int id = whisper_lang_id(lang_str);
if (-1 == id) {
rb_raise(rb_eArgError, "language not found: %s", lang_str);
}
return INT2NUM(id);
}

static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str = whisper_lang_str(lang_id);
if (nullptr == str) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str);
}

static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str_full = whisper_lang_str_full(lang_id);
if (nullptr == str_full) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str_full);
}

static void ruby_whisper_free(ruby_whisper *rw) {
if (rw->context) {
whisper_free(rw->context);
Expand All @@ -55,9 +86,12 @@ void rb_whisper_free(ruby_whisper *rw) {
}

void rb_whisper_params_mark(ruby_whisper_params *rwp) {
rb_gc_mark(rwp->new_segment_callback_user_data->user_data);
rb_gc_mark(rwp->new_segment_callback_user_data->callback);
}

void rb_whisper_params_free(ruby_whisper_params *rwp) {
// How to free user_data and callback only when not referred to by others?
ruby_whisper_params_free(rwp);
free(rwp);
}
Expand All @@ -71,8 +105,15 @@ static VALUE ruby_whisper_allocate(VALUE klass) {

static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
ruby_whisper_callback_user_data *new_segment_callback_user_data;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
new_segment_callback_user_data = ALLOC(ruby_whisper_callback_user_data);
new_segment_callback_user_data->context = nullptr;
new_segment_callback_user_data->user_data = Qnil;
new_segment_callback_user_data->callback = Qnil;
rwp->new_segment_callback_user_data = new_segment_callback_user_data;

return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}

Expand Down Expand Up @@ -206,6 +247,18 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}

if (!NIL_P(rwp->new_segment_callback_user_data->callback)) {
rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
const ruby_whisper_callback_user_data *container = (ruby_whisper_callback_user_data *)user_data;

// Currently, doesn't support state because
// those require to resolve GC-related problems.
rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data);
};
rwp->new_segment_callback_user_data->context = &self;
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_user_data;
}

if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
Expand All @@ -223,6 +276,58 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
return self;
}

static VALUE ruby_whisper_full_n_segments(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_n_segments(rw->context));
}

static VALUE ruby_whisper_full_lang_id(VALUE self) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_lang_id(rw->context));
}

static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) {
const int c_i_segment = NUM2INT(i_segment);
if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
}
return c_i_segment;
}

static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return INT2NUM(t0);
}

static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return INT2NUM(t1);
}

static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse;
}

static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) {
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text);
}

/*
* params.language = "auto" | "en", etc...
*/
Expand Down Expand Up @@ -365,16 +470,39 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) {
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_user_data->callback = value;
return value;
}
static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_user_data->user_data = value;
return value;
}

void Init_whisper() {
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);

rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);

rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);

rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);

rb_define_alloc_func(cParams, ruby_whisper_params_allocate);

Expand Down Expand Up @@ -412,6 +540,9 @@ void Init_whisper() {

rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0);
rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1);

rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
}
#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions bindings/ruby/ext/ruby_whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@

#include "whisper.h"

typedef struct {
VALUE *context;
VALUE user_data;
VALUE callback;
} ruby_whisper_callback_user_data;

typedef struct {
struct whisper_context *context;
} ruby_whisper;

typedef struct {
struct whisper_full_params params;
bool diarize;
ruby_whisper_callback_user_data *new_segment_callback_user_data;
} ruby_whisper_params;

#endif
2 changes: 0 additions & 2 deletions bindings/ruby/extsources.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,4 @@
../../examples:
- ext/dr_wav.h
../..:
- README.md
- LICENSE

76 changes: 76 additions & 0 deletions bindings/ruby/tests/test_callback.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
require "test/unit"
require "whisper"

class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))

def setup
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
end

def test_new_segment_callback
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_kind_of Integer, n_new
assert n_new > 0
assert_same @whisper, context

n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
start_time = context.full_get_segment_t0(i_segment) * 10
end_time = context.full_get_segment_t1(i_segment) * 10
text = context.full_get_segment_text(i_segment)

assert_kind_of Integer, start_time
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
end
}

@whisper.transcribe(@audio, @params)
end

def test_new_segment_callback_closure
search_word = "what"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
text = context.full_get_segment_text(i_segment)
if text.include?(search_word)
t0 = context.full_get_segment_t0(i_segment)
t1 = context.full_get_segment_t1(i_segment)
raise "search word '#{search_word}' found at between #{t0} and #{t1}"
end
end
}

assert_raise RuntimeError do
@whisper.transcribe(@audio, @params)
end
end

def test_new_segment_callback_user_data
udata = Object.new
@params.new_segment_callback_user_data = udata
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_same udata, user_data
}

@whisper.transcribe(@audio, @params)
end

def test_new_segment_callback_user_data_gc
@params.new_segment_callback_user_data = "My user data"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_equal "My user data", user_data
}
GC.start

assert_same @whisper, @whisper.transcribe(@audio, @params)
end
end
28 changes: 28 additions & 0 deletions bindings/ruby/tests/test_package.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'

class TestPackage < Test::Unit::TestCase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert_path_exist file.to_path
end
end

sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end

def test_install
filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
end
end
end
end
Loading
Loading