From 246793c7ab518f360dea4d5438478f048145645d Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:23:12 +0900 Subject: [PATCH 01/45] Add Params#new_segment_callback= method --- bindings/ruby/ext/ruby_whisper.cpp | 31 ++++++++++++++++++++++++++++++ bindings/ruby/ext/ruby_whisper.h | 1 + 2 files changed, 32 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9d9334539b8..96c435209a6 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -73,6 +73,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->new_segment_callback = Qnil; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -205,6 +206,28 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { }; rwp->params.encoder_begin_callback_user_data = &is_aborted; } + { + // This cannot be used later because it is not incremented when new_segment_callback is not given. 
+ static int n_segments = 0; + + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } + + for (int i = 0; i < n_new; i++) { + const int i_segment = n_segments + i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + n_segments += n_new; + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); @@ -365,6 +388,12 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback = value; + return value; +} void Init_whisper() { mWhisper = rb_define_module("Whisper"); @@ -412,6 +441,8 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); + + rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); } #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8c35b7cb65c..988750a8268 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ 
b/bindings/ruby/ext/ruby_whisper.h @@ -10,6 +10,7 @@ typedef struct { typedef struct { struct whisper_full_params params; bool diarize; + VALUE new_segment_callback; } ruby_whisper_params; #endif From 2ccddb957a220f9acb1dc649151d69e2274da203 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:23:43 +0900 Subject: [PATCH 02/45] Add tests for Params#new_segment_callback= --- bindings/ruby/tests/test_whisper.rb | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 410b5248a89..a496b3ae58f 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,29 @@ def test_whisper } end + def test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 0c1ff901404afec577aad05511eea3870998a6a4 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 01:29:50 +0900 Subject: [PATCH 03/45] Group tests for #transcribe --- bindings/ruby/tests/test_whisper.rb | 64 ++++++++++++++--------------- 1 file changed, 32 
insertions(+), 32 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index a496b3ae58f..0095ea20725 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -116,38 +116,38 @@ def test_split_on_word assert !@params.split_on_word end - def test_whisper - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - params = Whisper::Params.new - params.print_timestamps = false - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - @whisper.transcribe(jfk, params) {|text| - assert_match /ask not what your country can do for you, ask what you can do for your country/, text - } - end - - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) + sub_test_case "#transcribe" do + def setup + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @params = Whisper::Params.new + @params.print_timestamps = false + @jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + end + + def test_whisper + @whisper.transcribe(@jfk, @params) {|text| + assert_match /ask not what your country can do for you, ask what you can do for your country/, text + } + end + + def 
test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + @whisper.transcribe(@jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + @whisper.transcribe(@jfk, @params) + end end def test_build From 1d2d772bc5ca3d0f52a1be11409c4f0fa5614f07 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 03:09:30 +0900 Subject: [PATCH 04/45] Don't use static for thread-safety --- bindings/ruby/ext/ruby_whisper.cpp | 36 +++++++++++++----------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 96c435209a6..b16c67e00ef 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -206,28 +206,24 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { }; rwp->params.encoder_begin_callback_user_data = &is_aborted; } - { - // This cannot be used later because it is not incremented when new_segment_callback is not given. 
- static int n_segments = 0; - rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } - for (int i = 0; i < n_new; i++) { - const int i_segment = n_segments + i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } - n_segments += n_new; - }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; - } + int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; --i) { + const int i_segment = n_segments - i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); From 37aa3c691e7b3fd14961ed9d675a0f4a0b6f04ed Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: 
Fri, 18 Oct 2024 03:11:34 +0900 Subject: [PATCH 05/45] Set new_segment_callback only when necessary --- bindings/ruby/ext/ruby_whisper.cpp | 34 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index b16c67e00ef..77028939302 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -207,23 +207,25 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } + if (!NIL_P(rwp->new_segment_callback)) { + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } - int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; --i) { - const int i_segment = n_segments - i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } - }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; --i) { + const int i_segment = n_segments - i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because 
to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); From 3fec5d35a95dbea8ab56fbeea49b079952d4baa7 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 03:15:38 +0900 Subject: [PATCH 06/45] Remove redundant check --- bindings/ruby/ext/ruby_whisper.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 77028939302..1ee2453867d 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -210,10 +210,6 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { if (!NIL_P(rwp->new_segment_callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { VALUE callback = *(VALUE *)user_data; - if (NIL_P(callback)){ - return; - } - int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; --i) { const int i_segment = n_segments - i; From be67fa25cd4af4d7e5f98d90cc6dfbce9edd0260 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Fri, 18 Oct 2024 12:12:09 +0900 Subject: [PATCH 07/45] [skip ci] Add Ruby version README --- bindings/ruby/.gitignore | 1 - bindings/ruby/README.md | 63 +++++++++++++++++++++++++++++++++++ bindings/ruby/extsources.yaml | 1 - 3 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 bindings/ruby/README.md diff --git a/bindings/ruby/.gitignore b/bindings/ruby/.gitignore index 6ff6e5f2119..e04a90a9c69 100644 
--- a/bindings/ruby/.gitignore +++ b/bindings/ruby/.gitignore @@ -1,4 +1,3 @@ -README.md LICENSE pkg/ lib/whisper.* diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md new file mode 100644 index 00000000000..4c6d0e86587 --- /dev/null +++ b/bindings/ruby/README.md @@ -0,0 +1,63 @@ +whispercpp +========== + +![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg) + +Ruby bindings for [whisper.cpp][], an interface of automatic speech recognition model. + +Installation +------------ + +Install the gem and add to the application's Gemfile by executing: + + $ bundle add whispercpp + +If bundler is not being used to manage dependencies, install the gem by executing: + + $ gem install whispercpp + +Usage +----- + +NOTE: This gem is still in development. API is not stable for now. + +```ruby +require "whisper" + +whisper = Whisper::Context.new("path/to/model.bin") + +params = Whisper::Params.new +params.language = "en" +params.offset = 10_000 +params.duration = 60_000 +params.max_text_tokens = 300 +params.translate = true +params.print_timestamps = false +params.new_segment_callback = ->(output, t0, t1, index) { + puts "segment #{index}: #{t0}ms -> #{t1}ms: #{output}" +} + +whisper.transcribe("path/to/audio.wav", params) do |whole_text| + puts whole_text +end + +``` + +### Preparing model ### + +Use script to download model file(s): + +```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp +sh ./models/download-ggml-model.sh base.en +``` + +There are some types of models. See [models][] page for details. + +### Preparing audio file ### + +Currently, whisper.cpp accepts only 16-bit WAV files. 
+ +[whisper.cpp]: https://github.com/ggerganov/whisper.cpp +[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models diff --git a/bindings/ruby/extsources.yaml b/bindings/ruby/extsources.yaml index e59f6ecf0fb..52f0330d88d 100644 --- a/bindings/ruby/extsources.yaml +++ b/bindings/ruby/extsources.yaml @@ -33,5 +33,4 @@ ../../examples: - ext/dr_wav.h ../..: -- README.md - LICENSE From 58d3fb66570ea3065a91c9f5ccdadd0ba54fc551 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 02:36:12 +0900 Subject: [PATCH 08/45] Revert "Group tests for #transcribe" This reverts commit 71b65b00ccf1816c9ea8a247fb30f71bc09707d3. --- bindings/ruby/tests/test_whisper.rb | 64 ++++++++++++++--------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 0095ea20725..a496b3ae58f 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -116,38 +116,38 @@ def test_split_on_word assert !@params.split_on_word end - sub_test_case "#transcribe" do - def setup - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - @params = Whisper::Params.new - @params.print_timestamps = false - @jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - end - - def test_whisper - @whisper.transcribe(@jfk, @params) {|text| - assert_match /ask not what your country can do for you, ask what you can do for your country/, text - } - end - - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - @whisper.transcribe(@jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - 
@whisper.transcribe(@jfk, @params) - end + def test_whisper + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) {|text| + assert_match /ask not what your country can do for you, ask what you can do for your country/, text + } + end + + def test_new_segment_callback_lambda + counter = 0 + @params.new_segment_callback = ->(text, start_time, end_time, index) { + assert_kind_of String, text + assert_kind_of Integer, start_time + assert_kind_of Integer, end_time + assert_same index, counter + counter += 1 + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_proc + @params.new_segment_callback = proc {|text| # proc checks arguments loosly + assert_kind_of String, text + } + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) end def test_build From cee8c6af5b7b906feb1b568447de28a3d67db04a Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 02:36:26 +0900 Subject: [PATCH 09/45] Revert "Add tests for Params#new_segment_callback=" This reverts commit 81e6df3bab7662da5379db51f28a989db7408c02. 
--- bindings/ruby/tests/test_whisper.rb | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index a496b3ae58f..410b5248a89 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,29 +127,6 @@ def test_whisper } end - def test_new_segment_callback_lambda - counter = 0 - @params.new_segment_callback = ->(text, start_time, end_time, index) { - assert_kind_of String, text - assert_kind_of Integer, start_time - assert_kind_of Integer, end_time - assert_same index, counter - counter += 1 - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_proc - @params.new_segment_callback = proc {|text| # proc checks arguments loosly - assert_kind_of String, text - } - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From b589895e9642d58094aeb6b6d80c680a082900fa Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 04:16:23 +0900 Subject: [PATCH 10/45] Add test for Context#full_n_segments --- bindings/ruby/tests/test_whisper.rb | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 410b5248a89..09da562547c 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,28 @@ def test_whisper } end + sub_test_case "After transcription" do + class << self + attr_reader :whisper + + def startup + @whisper = 
Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) + end + end + + def whisper + self.class.whisper + end + + def test_full_n_segments + assert_equal 1, whisper.full_n_segments + end + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 8990107e7abc69f16472e31ed498f878d4e4fb67 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 04:16:39 +0900 Subject: [PATCH 11/45] Add Context#full_n_segments --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 1ee2453867d..5e912cb9180 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -240,6 +240,12 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { return self; } +static VALUE ruby_whisper_full_n_segments(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_n_segments(rw->context)); +} + /* * params.language = "auto" | "en", etc... 
*/ @@ -398,6 +404,7 @@ void Init_whisper() { rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From a175c5c6ec1648893c2d130f45d395b19c2c50a0 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:15:31 +0900 Subject: [PATCH 12/45] Add tests for lang API --- bindings/ruby/tests/test_whisper.rb | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 09da562547c..b2408ca5a32 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -149,6 +149,22 @@ def test_full_n_segments end end + def test_lang_max_id + assert_kind_of Integer, Whisper.lang_max_id + end + + def test_lang_id + assert_equal 0, Whisper.lang_id("en") + end + + def test_lang_str + assert_equal "en", Whisper.lang_str(0) + end + + def test_lang_str_full + assert_equal "english", Whisper.lang_str_full(0) + end + def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) From 1c74fdc0676a2287fde4f2aa4b24b1f907c6ba88 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:15:48 +0900 Subject: [PATCH 13/45] Add lang API --- bindings/ruby/ext/ruby_whisper.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 5e912cb9180..67669e64682 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -36,6 +36,22 @@ VALUE mWhisper; VALUE cContext; VALUE cParams; +static VALUE ruby_whisper_s_lang_max_id(VALUE self) { + return INT2NUM(whisper_lang_max_id()); +} + +static VALUE ruby_whisper_s_lang_id(VALUE self, 
VALUE lang) { + return INT2NUM(whisper_lang_id(StringValueCStr(lang))); +} + +static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { + return rb_str_new2(whisper_lang_str(NUM2INT(id))); +} + +static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { + return rb_str_new2(whisper_lang_str_full(NUM2INT(id))); +} + static void ruby_whisper_free(ruby_whisper *rw) { if (rw->context) { whisper_free(rw->context); @@ -400,6 +416,11 @@ void Init_whisper() { cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); + rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); + rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); + rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); + rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); From 588aa1ca85af7965f2706539fe38d1e37ca6b145 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:29:09 +0900 Subject: [PATCH 14/45] Add tests for Context#full_lang_id API --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index b2408ca5a32..69324b3d160 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -147,6 +147,10 @@ def whisper def test_full_n_segments assert_equal 1, whisper.full_n_segments end + + def test_full_lang_id + assert_equal 0, whisper.full_lang_id + end end def test_lang_max_id From 6e62076380bfd7097cfab252ced2280611f79561 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 05:29:19 +0900 Subject: [PATCH 15/45] Add Context#full_lang_id --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 
insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 67669e64682..0ec0b646ef6 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -262,6 +262,12 @@ static VALUE ruby_whisper_full_n_segments(VALUE self) { return INT2NUM(whisper_full_n_segments(rw->context)); } +static VALUE ruby_whisper_full_lang_id(VALUE self) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_lang_id(rw->context)); +} + /* * params.language = "auto" | "en", etc... */ @@ -426,6 +432,7 @@ void Init_whisper() { rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); + rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From ad55836c3ee70bd05eefc2e90e0a2655d39f09a1 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:01:38 +0900 Subject: [PATCH 16/45] Add abnormal test cases for lang --- bindings/ruby/tests/test_whisper.rb | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 69324b3d160..7dfa067311c 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -159,14 +159,23 @@ def test_lang_max_id def test_lang_id assert_equal 0, Whisper.lang_id("en") + assert_raise ArgumentError do + Whisper.lang_id("non existing language") + end end def test_lang_str assert_equal "en", Whisper.lang_str(0) + assert_raise IndexError do + Whisper.lang_str(Whisper.lang_max_id + 1) + end end def test_lang_str_full assert_equal "english", Whisper.lang_str_full(0) + assert_raise IndexError do + Whisper.lang_str_full(Whisper.lang_max_id + 1) + end end def test_build From e0255a5a12e5a3a31b9676429489308b28b1fc25 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 
06:01:52 +0900 Subject: [PATCH 17/45] Raise appropriate errors from lang APIs --- bindings/ruby/ext/ruby_whisper.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 0ec0b646ef6..db91cb6dbef 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -41,15 +41,30 @@ static VALUE ruby_whisper_s_lang_max_id(VALUE self) { } static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { - return INT2NUM(whisper_lang_id(StringValueCStr(lang))); + const char * lang_str = StringValueCStr(lang); + const int id = whisper_lang_id(lang_str); + if (-1 == id) { + rb_raise(rb_eArgError, "language not found: %s", lang_str); + } + return INT2NUM(id); } static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { - return rb_str_new2(whisper_lang_str(NUM2INT(id))); + const int lang_id = NUM2INT(id); + const char * str = whisper_lang_str(lang_id); + if (nullptr == str) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str); } static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { - return rb_str_new2(whisper_lang_str_full(NUM2INT(id))); + const int lang_id = NUM2INT(id); + const char * str_full = whisper_lang_str_full(lang_id); + if (nullptr == str_full) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str_full); } static void ruby_whisper_free(ruby_whisper *rw) { From 09eb66d84eca939fdd48d39cfaf225682caa8bd9 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:25:22 +0900 Subject: [PATCH 18/45] Add tests for Context#full_get_segment_t{0,1} API --- bindings/ruby/tests/test_whisper.rb | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 7dfa067311c..af2aca9fa49 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ 
b/bindings/ruby/tests/test_whisper.rb @@ -151,6 +151,25 @@ def test_full_n_segments def test_full_lang_id assert_equal 0, whisper.full_lang_id end + + def test_full_get_segment_t0 + assert_equal 0, whisper.full_get_segment_t0(0) + assert_raise IndexError do + whisper.full_get_segment_t0(whisper.full_n_segments) + end + assert_raise IndexError do + whisper.full_get_segment_t0(-1) + end + end + + def test_full_get_segment_t1 + t1 = whisper.full_get_segment_t1(0) + assert_kind_of Integer, t1 + assert t1 > 0 + assert_raise IndexError do + whisper.full_get_segment_t1(whisper.full_n_segments) + end + end end def test_lang_max_id From 4f261f6c1a849fe26ca526f4d1f604f6d0f43ad5 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:25:35 +0900 Subject: [PATCH 19/45] Add Context#full_get_segment_t{0,1} --- bindings/ruby/ext/ruby_whisper.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index db91cb6dbef..dc78ff9d258 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -283,6 +283,30 @@ static VALUE ruby_whisper_full_lang_id(VALUE self) { return INT2NUM(whisper_full_lang_id(rw->context)); } +static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) { + const int c_i_segment = NUM2INT(i_segment); + if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) { + rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment); + } + return c_i_segment; +} + +static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); + return INT2NUM(t0); +} + +static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { + 
ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); + return INT2NUM(t1); +} + /* * params.language = "auto" | "en", etc... */ @@ -448,6 +472,8 @@ void Init_whisper() { rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); + rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); + rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From d69e0bed77fda92e47a3ccbecc31cb481b61b8ee Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:41:28 +0900 Subject: [PATCH 20/45] Add tests for Context#full_get_segment_speaker_turn_next API --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index af2aca9fa49..207e8127a4e 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -170,6 +170,10 @@ def test_full_get_segment_t1 whisper.full_get_segment_t1(whisper.full_n_segments) end end + + def test_full_get_segment_speaker_turn_next + assert_false whisper.full_get_segment_speaker_turn_next(0) + end end def test_lang_max_id From 9902dcc888269f9736a8bd904d7cd0fc5fed1e07 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:41:36 +0900 Subject: [PATCH 21/45] Add Context#full_get_segment_speaker_turn_next --- bindings/ruby/ext/ruby_whisper.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index dc78ff9d258..d8134ff0b2c 100644 --- 
a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -307,6 +307,14 @@ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { return INT2NUM(t1); } +static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); + return speaker_turn_next ? Qtrue : Qfalse; +} + /* * params.language = "auto" | "en", etc... */ @@ -474,6 +482,7 @@ void Init_whisper() { rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); + rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From beba53939498c06ac18eefdf3b42de70e583d150 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:44:35 +0900 Subject: [PATCH 22/45] Add tests for Context#full_get_segment_text --- bindings/ruby/tests/test_whisper.rb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 207e8127a4e..48b95af94e5 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -174,6 +174,10 @@ def test_full_get_segment_t1 def test_full_get_segment_speaker_turn_next assert_false whisper.full_get_segment_speaker_turn_next(0) end + + def test_full_get_segment_text + assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0) + end end def test_lang_max_id From 63830b6c9a8d4cc57237a25b0dec4ff2ee838719 Mon 
Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 06:44:48 +0900 Subject: [PATCH 23/45] Add Context#full_get_segment_text --- bindings/ruby/ext/ruby_whisper.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index d8134ff0b2c..56abc6022f4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -315,6 +315,14 @@ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i return speaker_turn_next ? Qtrue : Qfalse; } +static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); + return rb_str_new2(text); +} + /* * params.language = "auto" | "en", etc... */ @@ -483,6 +491,7 @@ void Init_whisper() { rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); + rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); rb_define_alloc_func(cParams, ruby_whisper_params_allocate); From 0d1ec5f19b2fea4984fb096c406132085b992fa1 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 10:55:15 +0900 Subject: [PATCH 24/45] Add tests for Params#new_segment_callback= --- bindings/ruby/tests/test_whisper.rb | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 48b95af94e5..74e8b7fc7a7 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -127,6 +127,56 @@ def test_whisper
} end + def test_new_segment_callback + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same whisper, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 + end + } + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + whisper.transcribe(jfk, @params) + end + + def test_new_segment_callback_closure + whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + + search_word = "what" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + text = context.full_get_segment_text(i_segment) + if text.include?(search_word) + t0 = context.full_get_segment_t0(i_segment) + t1 = context.full_get_segment_t1(i_segment) + raise "search word '#{search_word}' found at between #{t0} and #{t1}" + end + end + } + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + assert_raise RuntimeError do + whisper.transcribe(jfk, @params) + end + end + sub_test_case "After transcription" do class << self attr_reader :whisper From 084f450e2ac21b41497a491c8551e803ffa87e58 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 10:56:02 +0900 Subject: [PATCH 25/45] Run new segment callback --- bindings/ruby/ext/ruby_whisper.cpp | 21 ++++++++++----------- bindings/ruby/ext/ruby_whisper.h | 1 + 2 files 
changed, 11 insertions(+), 11 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 56abc6022f4..d0344d423d4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -240,18 +240,17 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { if (!NIL_P(rwp->new_segment_callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - VALUE callback = *(VALUE *)user_data; - int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; --i) { - const int i_segment = n_segments - i; - const char * text = whisper_full_get_segment_text_from_state(state, i_segment); - // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it - const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; - const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; - rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); - } + ruby_whisper *rw;; + VALUE context = *(VALUE *)user_data; + Data_Get_Struct(context, ruby_whisper, rw); + VALUE callback = rw->new_segment_callback; + + // Currently, doesn't support state and user_data because + // those require to resolve GC-related problems. 
+ rb_funcall(callback, rb_intern("call"), 4, context, Qnil, INT2NUM(n_new), Qnil); }; - rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + rw->new_segment_callback = rwp->new_segment_callback; + rwp->params.new_segment_callback_user_data = &self; } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 988750a8268..1481bfa9e17 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,6 +5,7 @@ typedef struct { struct whisper_context *context; + VALUE new_segment_callback; } ruby_whisper; typedef struct { From 6128e05c71e4db3fd76db5b6377524abafaf3595 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 20:42:44 +0900 Subject: [PATCH 26/45] Split tests to multiple files --- bindings/ruby/tests/test_callback.rb | 56 ++++++++ bindings/ruby/tests/test_package.rb | 28 ++++ bindings/ruby/tests/test_params.rb | 112 ++++++++++++++++ bindings/ruby/tests/test_whisper.rb | 184 +-------------------------- 4 files changed, 198 insertions(+), 182 deletions(-) create mode 100644 bindings/ruby/tests/test_callback.rb create mode 100644 bindings/ruby/tests/test_package.rb create mode 100644 bindings/ruby/tests/test_params.rb diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb new file mode 100644 index 00000000000..644fc80d295 --- /dev/null +++ b/bindings/ruby/tests/test_callback.rb @@ -0,0 +1,56 @@ +require "test/unit" +require "whisper" + +class TestCallback < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + + def setup + @params = Whisper::Params.new + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, 
user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @whisper, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 + end + } + + @whisper.transcribe(@audio, @params) + end + + def test_new_segment_callback_closure + search_word = "what" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + text = context.full_get_segment_text(i_segment) + if text.include?(search_word) + t0 = context.full_get_segment_t0(i_segment) + t1 = context.full_get_segment_t1(i_segment) + raise "search word '#{search_word}' found at between #{t0} and #{t1}" + end + end + } + + assert_raise RuntimeError do + @whisper.transcribe(@audio, @params) + end + end +end diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb new file mode 100644 index 00000000000..9d7527340f2 --- /dev/null +++ b/bindings/ruby/tests/test_package.rb @@ -0,0 +1,28 @@ +require 'test/unit' +require 'tempfile' +require 'tmpdir' +require 'shellwords' + +class TestPackage < Test::Unit::TestCase + def test_build + Tempfile.create do |file| + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert_path_exist file.to_path + end + end + + sub_test_case "Building binary on installation" do + def setup + system "rake", "build", exception: true + end + + def test_install + filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] + basename = 
"whisper.#{RbConfig::CONFIG["DLEXT"]}" + Dir.mktmpdir do |dir| + system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true + assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) + end + end + end +end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb new file mode 100644 index 00000000000..4484feeeff1 --- /dev/null +++ b/bindings/ruby/tests/test_params.rb @@ -0,0 +1,112 @@ +require 'whisper' + +class TestParams < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + end + + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + 
end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true + assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens + end + + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end +end diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 74e8b7fc7a7..5ebb8151c65 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -1,121 +1,13 @@ -TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) - require 'whisper' require 'test/unit' -require 'tempfile' -require 'tmpdir' -require 'shellwords' class TestWhisper < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + def setup @params = Whisper::Params.new end - def test_language - @params.language = "en" - assert_equal @params.language, "en" - @params.language = "auto" - assert_equal @params.language, "auto" - end - - def test_offset - @params.offset = 10_000 - assert_equal @params.offset, 10_000 - @params.offset = 0 - assert_equal @params.offset, 0 - end - - def test_duration - @params.duration = 60_000 - assert_equal 
@params.duration, 60_000 - @params.duration = 0 - assert_equal @params.duration, 0 - end - - def test_max_text_tokens - @params.max_text_tokens = 300 - assert_equal @params.max_text_tokens, 300 - @params.max_text_tokens = 0 - assert_equal @params.max_text_tokens, 0 - end - - def test_translate - @params.translate = true - assert @params.translate - @params.translate = false - assert !@params.translate - end - - def test_no_context - @params.no_context = true - assert @params.no_context - @params.no_context = false - assert !@params.no_context - end - - def test_single_segment - @params.single_segment = true - assert @params.single_segment - @params.single_segment = false - assert !@params.single_segment - end - - def test_print_special - @params.print_special = true - assert @params.print_special - @params.print_special = false - assert !@params.print_special - end - - def test_print_progress - @params.print_progress = true - assert @params.print_progress - @params.print_progress = false - assert !@params.print_progress - end - - def test_print_realtime - @params.print_realtime = true - assert @params.print_realtime - @params.print_realtime = false - assert !@params.print_realtime - end - - def test_print_timestamps - @params.print_timestamps = true - assert @params.print_timestamps - @params.print_timestamps = false - assert !@params.print_timestamps - end - - def test_suppress_blank - @params.suppress_blank = true - assert @params.suppress_blank - @params.suppress_blank = false - assert !@params.suppress_blank - end - - def test_suppress_non_speech_tokens - @params.suppress_non_speech_tokens = true - assert @params.suppress_non_speech_tokens - @params.suppress_non_speech_tokens = false - assert !@params.suppress_non_speech_tokens - end - - def test_token_timestamps - @params.token_timestamps = true - assert @params.token_timestamps - @params.token_timestamps = false - assert !@params.token_timestamps - end - - def test_split_on_word - @params.split_on_word = true 
- assert @params.split_on_word - @params.split_on_word = false - assert !@params.split_on_word - end - def test_whisper @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) params = Whisper::Params.new @@ -127,56 +19,6 @@ def test_whisper } end - def test_new_segment_callback - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - - @params.new_segment_callback = ->(context, state, n_new, user_data) { - assert_kind_of Integer, n_new - assert n_new > 0 - assert_same whisper, context - - n_segments = context.full_n_segments - n_new.times do |i| - i_segment = n_segments - 1 + i - start_time = context.full_get_segment_t0(i_segment) * 10 - end_time = context.full_get_segment_t1(i_segment) * 10 - text = context.full_get_segment_text(i_segment) - - assert_kind_of Integer, start_time - assert start_time >= 0 - assert_kind_of Integer, end_time - assert end_time > 0 - assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0 - end - } - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - whisper.transcribe(jfk, @params) - end - - def test_new_segment_callback_closure - whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - - search_word = "what" - @params.new_segment_callback = ->(context, state, n_new, user_data) { - n_segments = context.full_n_segments - n_new.times do |i| - i_segment = n_segments - 1 + i - text = context.full_get_segment_text(i_segment) - if text.include?(search_word) - t0 = context.full_get_segment_t0(i_segment) - t1 = context.full_get_segment_t1(i_segment) - raise "search word '#{search_word}' found at between #{t0} and #{t1}" - end - end - } - - jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') - assert_raise RuntimeError do - whisper.transcribe(jfk, @params) - end - end - sub_test_case "After transcription" do class << self attr_reader :whisper @@ -254,26 
+96,4 @@ def test_lang_str_full Whisper.lang_str_full(Whisper.lang_max_id + 1) end end - - def test_build - Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) - assert_path_exist file.to_path - end - end - - sub_test_case "Building binary on installation" do - def setup - system "rake", "build", exception: true - end - - def test_install - filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1] - basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}" - Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true - assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename) - end - end - end end From c2de24a3ab998dc1885fbd098e7fb29a64d5749b Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:43:58 +0900 Subject: [PATCH 27/45] Use container struct for new segment callback --- bindings/ruby/ext/ruby_whisper.cpp | 24 +++++++++++++----------- bindings/ruby/ext/ruby_whisper.h | 9 +++++++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index d0344d423d4..9ccdc2a6f7c 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -102,9 +102,14 @@ static VALUE ruby_whisper_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; + ruby_whisper_callback_user_data *new_segment_callback_user_data; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - rwp->new_segment_callback = Qnil; + new_segment_callback_user_data = ALLOC(ruby_whisper_callback_user_data); + new_segment_callback_user_data->context = nullptr; + new_segment_callback_user_data->user_data = Qnil; + new_segment_callback_user_data->callback = Qnil; + rwp->new_segment_callback_user_data = 
new_segment_callback_user_data; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -238,19 +243,16 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - if (!NIL_P(rwp->new_segment_callback)) { + if (!NIL_P(rwp->new_segment_callback_user_data->callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - ruby_whisper *rw;; - VALUE context = *(VALUE *)user_data; - Data_Get_Struct(context, ruby_whisper, rw); - VALUE callback = rw->new_segment_callback; + const ruby_whisper_callback_user_data *container = (ruby_whisper_callback_user_data *)user_data; - // Currently, doesn't support state and user_data because + // Currently, doesn't support state because // those require to resolve GC-related problems. - rb_funcall(callback, rb_intern("call"), 4, context, Qnil, INT2NUM(n_new), Qnil); + rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); }; - rw->new_segment_callback = rwp->new_segment_callback; - rwp->params.new_segment_callback_user_data = &self; + rwp->new_segment_callback_user_data->context = &self; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_user_data; } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { @@ -467,7 +469,7 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback = value; + rwp->new_segment_callback_user_data->callback = value; return value; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 1481bfa9e17..033b780e94b 100644 --- a/bindings/ruby/ext/ruby_whisper.h 
+++ b/bindings/ruby/ext/ruby_whisper.h @@ -3,15 +3,20 @@ #include "whisper.h" +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; +} ruby_whisper_callback_user_data; + typedef struct { struct whisper_context *context; - VALUE new_segment_callback; } ruby_whisper; typedef struct { struct whisper_full_params params; bool diarize; - VALUE new_segment_callback; + ruby_whisper_callback_user_data *new_segment_callback_user_data; } ruby_whisper_params; #endif From 3f2801346e76d5adeb2d572788fd7e3893618a46 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:54:04 +0900 Subject: [PATCH 28/45] Add tests for Params#new_segment_callback_user_data= --- bindings/ruby/tests/test_callback.rb | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 644fc80d295..5697079fbad 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -53,4 +53,14 @@ def test_new_segment_callback_closure @whisper.transcribe(@audio, @params) end end + + def test_new_segment_callback_user_data + udata = Object.new + @params.new_segment_callback_user_data = udata + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_same udata, user_data + } + + @whisper.transcribe(@audio, @params) + end end From bb4e81c9f2749fbc103c40b12fb38d12ebbec63e Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 21:54:36 +0900 Subject: [PATCH 29/45] Add Whisper::Params#new_segment_callback_user_data= --- bindings/ruby/ext/ruby_whisper.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9ccdc2a6f7c..168cb98320d 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -472,6 +472,12 @@ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE valu rwp->new_segment_callback_user_data->callback
= value; return value; } +static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_user_data->user_data = value; + return value; +} void Init_whisper() { mWhisper = rb_define_module("Whisper"); @@ -532,6 +538,7 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); + rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); } #ifdef __cplusplus } From f41150786d885fc8b5480c13550bc271ea7a1ded Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 23:20:30 +0900 Subject: [PATCH 30/45] Add GC-related test for new segment callback --- bindings/ruby/tests/test_callback.rb | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 5697079fbad..80a5f4dfae6 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -63,4 +63,14 @@ def test_new_segment_callback_user_data @whisper.transcribe(@audio, @params) end + + def test_new_segment_callback_user_data_gc + @params.new_segment_callback_user_data = "My user data" + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_equal "My user data", user_data + } + GC.start + + assert_same @whisper, @whisper.transcribe(@audio, @params) + end end From 79617d7bc6d0ef8551320fdc27a2d746d940a717 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Sun, 20 Oct 2024 23:20:48 +0900 Subject: [PATCH 31/45] Protect new segment callback related structs from GC --- bindings/ruby/ext/ruby_whisper.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp 
index 168cb98320d..8dd935dda08 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -86,9 +86,12 @@ void rb_whisper_free(ruby_whisper *rw) { } void rb_whisper_params_mark(ruby_whisper_params *rwp) { + rb_gc_mark(rwp->new_segment_callback_user_data->user_data); + rb_gc_mark(rwp->new_segment_callback_user_data->callback); } void rb_whisper_params_free(ruby_whisper_params *rwp) { + // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); free(rwp); } @@ -110,6 +113,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { new_segment_callback_user_data->user_data = Qnil; new_segment_callback_user_data->callback = Qnil; rwp->new_segment_callback_user_data = new_segment_callback_user_data; + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } From eae174eef1214edee1a5dd3d7e833b17d5c2a25d Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 21 Oct 2024 01:18:21 +0900 Subject: [PATCH 32/45] Add meaningful test for build --- bindings/ruby/tests/test_package.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/ruby/tests/test_package.rb b/bindings/ruby/tests/test_package.rb index 9d7527340f2..adaeedfbbae 100644 --- a/bindings/ruby/tests/test_package.rb +++ b/bindings/ruby/tests/test_package.rb @@ -7,7 +7,7 @@ class TestPackage < Test::Unit::TestCase def test_build Tempfile.create do |file| assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) - assert_path_exist file.to_path + assert file.size > 0 end end From 30c00c1375172c9a0c1e9aa53ab29ef0a9296aa4 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 21 Oct 2024 03:52:21 +0900 Subject: [PATCH 33/45] Rename: new_segment_callback_user_data -> new_segment_callback_container --- bindings/ruby/ext/ruby_whisper.cpp | 28 ++++++++++++++-------------- bindings/ruby/ext/ruby_whisper.h | 4 ++-- 2 files changed, 16 
insertions(+), 16 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 8dd935dda08..0722273364a 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -86,8 +86,8 @@ void rb_whisper_free(ruby_whisper *rw) { } void rb_whisper_params_mark(ruby_whisper_params *rwp) { - rb_gc_mark(rwp->new_segment_callback_user_data->user_data); - rb_gc_mark(rwp->new_segment_callback_user_data->callback); + rb_gc_mark(rwp->new_segment_callback_container->user_data); + rb_gc_mark(rwp->new_segment_callback_container->callback); } void rb_whisper_params_free(ruby_whisper_params *rwp) { @@ -105,14 +105,14 @@ static VALUE ruby_whisper_allocate(VALUE klass) { static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; - ruby_whisper_callback_user_data *new_segment_callback_user_data; + ruby_whisper_callback_container *new_segment_callback_container; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - new_segment_callback_user_data = ALLOC(ruby_whisper_callback_user_data); - new_segment_callback_user_data->context = nullptr; - new_segment_callback_user_data->user_data = Qnil; - new_segment_callback_user_data->callback = Qnil; - rwp->new_segment_callback_user_data = new_segment_callback_user_data; + new_segment_callback_container = ALLOC(ruby_whisper_callback_container); + new_segment_callback_container->context = nullptr; + new_segment_callback_container->user_data = Qnil; + new_segment_callback_container->callback = Qnil; + rwp->new_segment_callback_container = new_segment_callback_container; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -247,16 +247,16 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - if (!NIL_P(rwp->new_segment_callback_user_data->callback)) { + if 
(!NIL_P(rwp->new_segment_callback_container->callback)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { - const ruby_whisper_callback_user_data *container = (ruby_whisper_callback_user_data *)user_data; + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; // Currently, doesn't support state because // those require to resolve GC-related problems. rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); }; - rwp->new_segment_callback_user_data->context = &self; - rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_user_data; + rwp->new_segment_callback_container->context = &self; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { @@ -473,13 +473,13 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback_user_data->callback = value; + rwp->new_segment_callback_container->callback = value; return value; } static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback_user_data->user_data = value; + rwp->new_segment_callback_container->user_data = value; return value; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 033b780e94b..0e503a3df3b 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -7,7 +7,7 @@ typedef struct { VALUE *context; VALUE user_data; VALUE callback; -} ruby_whisper_callback_user_data; +} 
ruby_whisper_callback_container; typedef struct { struct whisper_context *context; @@ -16,7 +16,7 @@ typedef struct { typedef struct { struct whisper_full_params params; bool diarize; - ruby_whisper_callback_user_data *new_segment_callback_user_data; + ruby_whisper_callback_container *new_segment_callback_container; } ruby_whisper_params; #endif From e7f75f15de787edae97b18cd0e8bc10c82e2a340 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 21 Oct 2024 16:07:32 +0900 Subject: [PATCH 34/45] Add tests for Whisper::Segment --- bindings/ruby/tests/test_segment.rb | 55 +++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 bindings/ruby/tests/test_segment.rb diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb new file mode 100644 index 00000000000..037bf35e23e --- /dev/null +++ b/bindings/ruby/tests/test_segment.rb @@ -0,0 +1,55 @@ +require "test/unit" +require "whisper" + +class TestSegment < Test::Unit::TestCase + TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) + + class << self + attr_reader :whisper + + def startup + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) + end + end + + def test_iteration + whisper.each_segment do |segment| + assert_instance_of Whisper::Segment, segment + end + end + + def test_enumerator + enum = whisper.each_segment + assert_instance_of Enumerator, enum + enum.to_a.each_with_index do |segment, index| + assert_instance_of Whisper::Segment, segment + assert_kind_of Integer, index + end + end + + def test_start_time + i = 0 + whisper.each_segment do |segment| + assert_equal 0, segment.start_time if i == 0 + i += 1 + end + end + + def test_end_time + i = 0 + whisper.each_segment do |segment| + assert_equal whisper.full_get_segment_t1(i) * 10, 
segment.end_time + i += 1 + end + end + + private + + def whisper + self.class.whisper + end +end From 326055a926c27daa3dd378dde3606f386061050a Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 21 Oct 2024 16:07:48 +0900 Subject: [PATCH 35/45] Add Whisper::Segment and Whisper::Context#each_segment --- bindings/ruby/ext/ruby_whisper.cpp | 92 ++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 0722273364a..9ce98da9117 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -483,6 +483,88 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, return value; } +// High level API + +typedef struct { + VALUE context; + int index; +} ruby_whisper_segment; + +VALUE cSegment; + +static void rb_whisper_segment_mark(ruby_whisper_segment *rws) { + rb_gc_mark(rws->context); +} + +static VALUE ruby_whisper_segment_allocate(VALUE klass) { + ruby_whisper_segment *rws; + rws = ALLOC(ruby_whisper_segment); + return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws); +} + +static VALUE rb_whisper_segment_initialize(VALUE context, int index) { + ruby_whisper_segment *rws; + const VALUE segment = ruby_whisper_segment_allocate(cSegment); + Data_Get_Struct(segment, ruby_whisper_segment, rws); + rws->context = context; + rws->index = index; + return segment; +}; + +static VALUE ruby_whisper_each_segment(VALUE self) { + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, rb_intern("__method__"), 0); + return rb_funcall(self, rb_intern("to_enum"), 1, method_name); + } + + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + + const int n_segments = whisper_full_n_segments(rw->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(rb_whisper_segment_initialize(self, i)); + } + + return self; +} + +static VALUE ruby_whisper_segment_get_start_time(VALUE self) { + 
ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t0 * 10); +} + +static VALUE ruby_whisper_segment_get_end_time(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t1 * 10); +} + +static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? 
Qtrue : Qfalse; +} + +static VALUE ruby_whisper_segment_get_text(VALUE self) { + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const char * text = whisper_full_get_segment_text(rw->context, rws->index); + return rb_str_new2(text); +} + void Init_whisper() { mWhisper = rb_define_module("Whisper"); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); @@ -543,6 +625,16 @@ void Init_whisper() { rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); + + // High level API + cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject); + + rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); + rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); + rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); + rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); + rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); + rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); } #ifdef __cplusplus } From 1132c9e39d465102942506be4376487382c72a2e Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 21 Oct 2024 21:36:25 +0900 Subject: [PATCH 36/45] Extract c_ruby_whisper_callback_container_allocate() --- bindings/ruby/ext/ruby_whisper.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9ce98da9117..8ba9027aa4f 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -73,6 +73,7 @@ static void ruby_whisper_free(ruby_whisper *rw) { rw->context = NULL; } } + static void ruby_whisper_params_free(ruby_whisper_params 
*rwp) { } @@ -103,17 +104,20 @@ static VALUE ruby_whisper_allocate(VALUE klass) { return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); } +static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() { + ruby_whisper_callback_container *container; + container = ALLOC(ruby_whisper_callback_container); + container->context = nullptr; + container->user_data = Qnil; + container->callback = Qnil; + return container; +} + static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; - ruby_whisper_callback_container *new_segment_callback_container; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - new_segment_callback_container = ALLOC(ruby_whisper_callback_container); - new_segment_callback_container->context = nullptr; - new_segment_callback_container->user_data = Qnil; - new_segment_callback_container->callback = Qnil; - rwp->new_segment_callback_container = new_segment_callback_container; - + rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } From ba0fbec63ac8c2ee0c62e38303f0da7c7adfdcbf Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Wed, 23 Oct 2024 05:26:21 +0900 Subject: [PATCH 37/45] Add test for Whisper::Params#on_new_segment --- bindings/ruby/tests/test_segment.rb | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb index 037bf35e23e..debc934a860 100644 --- a/bindings/ruby/tests/test_segment.rb +++ b/bindings/ruby/tests/test_segment.rb @@ -47,6 +47,24 @@ def test_end_time end end + def test_on_new_segment + params = Whisper::Params.new + seg = nil + index = 0 + params.on_new_segment do |segment| + assert_instance_of Whisper::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match /ask not what your 
country can do for you, ask what you can do for your country/, segment.text + end + index += 1 + end + whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text + end + private def whisper From 56c2dfd028f5dcb384650bd19872da574b2450a7 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Wed, 23 Oct 2024 05:26:40 +0900 Subject: [PATCH 38/45] Add Whisper::Params#on_new_segment --- bindings/ruby/ext/ruby_whisper.cpp | 33 ++++++++++++++++++++++++++++-- bindings/ruby/ext/ruby_whisper.h | 1 + 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 8ba9027aa4f..691246f5011 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -89,6 +89,7 @@ void rb_whisper_free(ruby_whisper *rw) { void rb_whisper_params_mark(ruby_whisper_params *rwp) { rb_gc_mark(rwp->new_segment_callback_container->user_data); rb_gc_mark(rwp->new_segment_callback_container->callback); + rb_gc_mark(rwp->new_segment_callback_container->callbacks); } void rb_whisper_params_free(ruby_whisper_params *rwp) { @@ -110,6 +111,7 @@ static ruby_whisper_callback_container * rb_whisper_callback_container_allocate( container->context = nullptr; container->user_data = Qnil; container->callback = Qnil; + container->callbacks = rb_ary_new(); return container; } @@ -139,6 +141,9 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { return self; } +// High level API +static VALUE rb_whisper_segment_initialize(VALUE context, int index); + /* * transcribe a single file * can emit to a block results @@ -251,13 +256,28 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rwp->params.encoder_begin_callback_user_data = &is_aborted; } - if (!NIL_P(rwp->new_segment_callback_container->callback)) { + if 
(!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; // Currently, doesn't support state because // those require to resolve GC-related problems. - rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + const int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, rb_intern("call"), 1, segment); + } + } }; rwp->new_segment_callback_container->context = &self; rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; @@ -532,6 +552,14 @@ static VALUE ruby_whisper_each_segment(VALUE self) { return self; } +static VALUE ruby_whisper_params_on_new_segment(VALUE self) { + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->new_segment_callback_container->callbacks, blk); + return Qnil; +} + static VALUE ruby_whisper_segment_get_start_time(VALUE self) { ruby_whisper_segment *rws; Data_Get_Struct(self, ruby_whisper_segment, rws); @@ -635,6 +663,7 @@ void Init_whisper() { rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); rb_define_method(cContext, "each_segment", 
ruby_whisper_each_segment, 0); + rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 0e503a3df3b..f3210a4ba48 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -7,6 +7,7 @@ typedef struct { VALUE *context; VALUE user_data; VALUE callback; + VALUE callbacks; } ruby_whisper_callback_container; typedef struct { From 87797cc62fed404ffbfc8e1aaf2e784a5bd5c2dd Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Wed, 23 Oct 2024 16:27:16 +0900 Subject: [PATCH 39/45] Assign symbol IDs to variables --- bindings/ruby/ext/ruby_whisper.cpp | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 691246f5011..002b118f716 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -36,6 +36,11 @@ VALUE mWhisper; VALUE cContext; VALUE cParams; +static ID id_to_s; +static ID id_call; +static ID id___method__; +static ID id_to_enum; + static VALUE ruby_whisper_s_lang_max_id(VALUE self) { return INT2NUM(whisper_lang_max_id()); } @@ -131,7 +136,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_scan_args(argc, argv, "01", &whisper_model_file_path); Data_Get_Struct(self, ruby_whisper, rw); - if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { + if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), 
whisper_context_default_params()); @@ -158,7 +163,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { Data_Get_Struct(self, ruby_whisper, rw); Data_Get_Struct(params, ruby_whisper_params, rwp); - if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { + if (!rb_respond_to(wave_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } @@ -263,7 +268,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { - rb_funcall(container->callback, rb_intern("call"), 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { @@ -275,7 +280,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, rb_intern("call"), 1, segment); + rb_funcall(cb, id_call, 1, segment); } } }; @@ -293,7 +298,7 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { const char * text = whisper_full_get_segment_text(rw->context, i); output = rb_str_concat(output, rb_str_new2(text)); } - VALUE idCall = rb_intern("call"); + VALUE idCall = id_call; if (blk != Qnil) { rb_funcall(blk, idCall, 1, output); } @@ -537,8 +542,8 @@ static VALUE rb_whisper_segment_initialize(VALUE context, int index) { static VALUE ruby_whisper_each_segment(VALUE self) { if (!rb_block_given_p()) { - const VALUE method_name = rb_funcall(self, rb_intern("__method__"), 0); - return rb_funcall(self, rb_intern("to_enum"), 1, method_name); + const VALUE method_name = rb_funcall(self, 
id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); } ruby_whisper *rw; @@ -598,6 +603,11 @@ static VALUE ruby_whisper_segment_get_text(VALUE self) { } void Init_whisper() { + id_to_s = rb_intern("to_s"); + id_call = rb_intern("call"); + id___method__ = rb_intern("__method__"); + id_to_enum = rb_intern("to_enum"); + mWhisper = rb_define_module("Whisper"); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); From d94710ca7f90159ea970669535c3136ee98b0e43 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Wed, 23 Oct 2024 17:53:20 +0900 Subject: [PATCH 40/45] Make extsources.yaml simpler --- bindings/ruby/Rakefile | 17 ++++----- bindings/ruby/extsources.yaml | 63 ++++++++++++++------------------ bindings/ruby/whispercpp.gemspec | 10 ++++- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 9b2787e904c..5a6a9167a9f 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -5,17 +5,16 @@ require "yaml" require "rake/testtask" extsources = YAML.load_file("extsources.yaml") -extsources.each_pair do |src_dir, dests| - dests.each do |dest| - src = Pathname(src_dir)/File.basename(dest) - - file src - file dest => src do |t| - cp t.source, t.name - end +SOURCES = FileList[] +extsources.each do |src| + basename = src.pathmap("%f") + dest = basename == "LICENSE" ? 
basename : basename.pathmap("ext/%f") + file src + file dest => src do |t| + cp t.source, t.name end + SOURCES.include dest end -SOURCES = extsources.values.flatten CLEAN.include SOURCES CLEAN.include FileList[ "ext/*.o", diff --git a/bindings/ruby/extsources.yaml b/bindings/ruby/extsources.yaml index 52f0330d88d..85488864a18 100644 --- a/bindings/ruby/extsources.yaml +++ b/bindings/ruby/extsources.yaml @@ -1,36 +1,29 @@ --- -../../src: -- ext/whisper.cpp -../../include: -- ext/whisper.h -../../ggml/src: -- ext/ggml.c -- ext/ggml-impl.h -- ext/ggml-aarch64.h -- ext/ggml-aarch64.c -- ext/ggml-alloc.c -- ext/ggml-backend-impl.h -- ext/ggml-backend.cpp -- ext/ggml-common.h -- ext/ggml-quants.h -- ext/ggml-quants.c -- ext/ggml-cpu-impl.h -- ext/ggml-metal.m -- ext/ggml-metal.metal -- ext/ggml-blas.cpp -../../ggml/include: -- ext/ggml.h -- ext/ggml-alloc.h -- ext/ggml-backend.h -- ext/ggml-cuda.h -- ext/ggml-kompute.h -- ext/ggml-metal.h -- ext/ggml-sycl.h -- ext/ggml-vulkan.h -- ext/ggml-blas.h -../../scripts: -- ext/get-flags.mk -../../examples: -- ext/dr_wav.h -../..: -- LICENSE +- ../../src/whisper.cpp +- ../../include/whisper.h +- ../../ggml/src/ggml.c +- ../../ggml/src/ggml-impl.h +- ../../ggml/src/ggml-aarch64.h +- ../../ggml/src/ggml-aarch64.c +- ../../ggml/src/ggml-alloc.c +- ../../ggml/src/ggml-backend-impl.h +- ../../ggml/src/ggml-backend.cpp +- ../../ggml/src/ggml-common.h +- ../../ggml/src/ggml-quants.h +- ../../ggml/src/ggml-quants.c +- ../../ggml/src/ggml-cpu-impl.h +- ../../ggml/src/ggml-metal.m +- ../../ggml/src/ggml-metal.metal +- ../../ggml/src/ggml-blas.cpp +- ../../ggml/include/ggml.h +- ../../ggml/include/ggml-alloc.h +- ../../ggml/include/ggml-backend.h +- ../../ggml/include/ggml-cuda.h +- ../../ggml/include/ggml-kompute.h +- ../../ggml/include/ggml-metal.h +- ../../ggml/include/ggml-sycl.h +- ../../ggml/include/ggml-vulkan.h +- ../../ggml/include/ggml-blas.h +- ../../scripts/get-flags.mk +- ../../examples/dr_wav.h +- ../../LICENSE diff --git 
a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 5b24d7e795f..251d03faa06 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -9,7 +9,15 @@ Gem::Specification.new do |s| s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] - s.files = `git ls-files . -z`.split("\x0") + YAML.load_file("extsources.yaml").values.flatten + s.files = `git ls-files . -z`.split("\x0") + + YAML.load_file("extsources.yaml").collect {|file| + basename = File.basename(file) + if s.extra_rdoc_files.include?(basename) + basename + else + File.join("ext", basename) + end + } s.summary = %q{Ruby whisper.cpp bindings} s.test_files = ["tests/test_whisper.rb"] From a0cfc229c4bf202a7a6ce1747312f529dc5f94a8 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Thu, 24 Oct 2024 05:26:15 +0900 Subject: [PATCH 41/45] Update README --- bindings/ruby/README.md | 57 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 4c6d0e86587..29dba120d26 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -19,8 +19,6 @@ If bundler is not being used to manage dependencies, install the gem by executin Usage ----- -NOTE: This gem is still in development. API is not stable for now. - ```ruby require "whisper" @@ -33,9 +31,6 @@ params.duration = 60_000 params.max_text_tokens = 300 params.translate = true params.print_timestamps = false -params.new_segment_callback = ->(output, t0, t1, index) { - puts "segment #{index}: #{t0}ms -> #{t1}ms: #{output}" -} whisper.transcribe("path/to/audio.wav", params) do |whole_text| puts whole_text @@ -59,5 +54,57 @@ There are some types of models. See [models][] page for details. Currently, whisper.cpp accepts only 16-bit WAV files. 
+### API ### + +Once `Whisper::Context#transcribe` has been called, you can retrieve segments with `#each_segment`: + +```ruby +def format_time(time_ms) + sec, decimal_part = time_ms.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] +end + +whisper.transcribe("path/to/audio.wav", params) + +whisper.each_segment.with_index do |segment, index| + line = "[%{nth}: %{st} --> %{ed}] %{text}" % { + nth: index + 1, + st: format_time(segment.start_time), + ed: format_time(segment.end_time), + text: segment.text + } + line << " (speaker turned)" if segment.speaker_next_turn? + puts line +end + +``` + +You can also add a hook to params that is called on each new segment: + +```ruby +def format_time(time_ms) + sec, decimal_part = time_ms.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] +end + +# Add hook before calling #transcribe +params.on_new_segment do |segment| + line = "[%{st} --> %{ed}] %{text}" % { + st: format_time(segment.start_time), + ed: format_time(segment.end_time), + text: segment.text + } + line << " (speaker turned)" if segment.speaker_next_turn? 
+ puts line +end + +whisper.transcribe("path/to/audio.wav", params) + +``` + [whisper.cpp]: https://github.com/ggerganov/whisper.cpp [models]: https://github.com/ggerganov/whisper.cpp/tree/master/models From 2f925ca19dd39c3e329c2a306bd7796e924a32eb Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Thu, 24 Oct 2024 10:13:49 +0900 Subject: [PATCH 42/45] Add document comments --- bindings/ruby/ext/ruby_whisper.cpp | 292 +++++++++++++++++++++++++++++ 1 file changed, 292 insertions(+) diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 002b118f716..b17a6bca4be 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -41,10 +41,18 @@ static ID id_call; static ID id___method__; static ID id_to_enum; +/* + * call-seq: + * lang_max_id -> Integer + */ static VALUE ruby_whisper_s_lang_max_id(VALUE self) { return INT2NUM(whisper_lang_max_id()); } +/* + * call-seq: + * lang_id(lang_name) -> Integer + */ static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { const char * lang_str = StringValueCStr(lang); const int id = whisper_lang_id(lang_str); @@ -54,6 +62,10 @@ static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { return INT2NUM(id); } +/* + * call-seq: + * lang_str(lang_id) -> String + */ static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { const int lang_id = NUM2INT(id); const char * str = whisper_lang_str(lang_id); @@ -63,6 +75,10 @@ static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { return rb_str_new2(str); } +/* + * call-seq: + * lang_str_full(lang_id) -> String + */ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { const int lang_id = NUM2INT(id); const char * str_full = whisper_lang_str_full(lang_id); @@ -128,6 +144,10 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } +/* + * call-seq: + * new("path/to/model.bin") -> Whisper::Context + */ static VALUE 
ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; VALUE whisper_model_file_path; @@ -153,6 +173,14 @@ static VALUE rb_whisper_segment_initialize(VALUE context, int index); * transcribe a single file * can emit to a block results * + * params = Whisper::Params.new + * params.duration = 60_000 + * whisper.transcribe "path/to/audio.wav", params do |text| + * puts text + * end + * + * call-seq: + * transcribe(path_to_audio, params) {|text| ...} **/ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; @@ -305,12 +333,24 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { return self; } +/* + * Number of segments. + * + * call-seq: + * full_n_segments -> Integer + */ static VALUE ruby_whisper_full_n_segments(VALUE self) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); return INT2NUM(whisper_full_n_segments(rw->context)); } +/* + * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full. + * + * call-seq: + * full_lang_id -> Integer + */ static VALUE ruby_whisper_full_lang_id(VALUE self) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); @@ -325,6 +365,14 @@ static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const return c_i_segment; } +/* + * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + * + * full_get_segment_t0(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t0(segment_index) -> Integer + */ static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); @@ -333,6 +381,14 @@ static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { return INT2NUM(t0); } +/* + * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). 
+ * + * full_get_segment_t1(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t1(segment_index) -> Integer + */ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); @@ -341,6 +397,14 @@ static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { return INT2NUM(t1); } +/* + * Whether the next segment indexed by +segment_index+ is predicted as a speaker turn. + * + * full_get_segment_speaker_turn_next(3) # => true + * + * call-seq: + * full_get_segment_speaker_turn_next(segment_index) -> bool + */ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); @@ -349,6 +413,14 @@ static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i return speaker_turn_next ? Qtrue : Qfalse; } +/* + * Text of a segment indexed by +segment_index+. + * + * full_get_segment_text(3) # => "ask not what your country can do for you, ..." + * + * call-seq: + * full_get_segment_text(segment_index) -> String + */ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { ruby_whisper *rw; Data_Get_Struct(self, ruby_whisper, rw); @@ -359,6 +431,9 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { /* * params.language = "auto" | "en", etc... 
+ * + * call-seq: + * language = lang_name -> lang_name */ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; @@ -370,6 +445,10 @@ static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { } return value; } +/* + * call-seq: + * language -> String + */ static VALUE ruby_whisper_params_get_language(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -379,72 +458,185 @@ static VALUE ruby_whisper_params_get_language(VALUE self) { return rb_str_new2("auto"); } } +/* + * call-seq: + * translate = do_translate -> do_translate + */ static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, translate, value) } +/* + * call-seq: + * translate -> bool + */ static VALUE ruby_whisper_params_get_translate(VALUE self) { BOOL_PARAMS_GETTER(self, translate) } +/* + * call-seq: + * no_context = dont_use_context -> dont_use_context + */ static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, no_context, value) } +/* + * If true, does not use past transcription (if any) as initial prompt for the decoder. + * + * call-seq: + * no_context -> bool + */ static VALUE ruby_whisper_params_get_no_context(VALUE self) { BOOL_PARAMS_GETTER(self, no_context) } +/* + * call-seq: + * single_segment = force_single -> force_single + */ static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, single_segment, value) } +/* + * If true, forces single segment output (useful for streaming). 
+ * + * call-seq: + * single_segment -> bool + */ static VALUE ruby_whisper_params_get_single_segment(VALUE self) { BOOL_PARAMS_GETTER(self, single_segment) } +/* + * call-seq: + * print_special = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_special, value) } +/* + * If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.). + * + * call-seq: + * print_special -> bool + */ static VALUE ruby_whisper_params_get_print_special(VALUE self) { BOOL_PARAMS_GETTER(self, print_special) } +/* + * call-seq: + * print_progress = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_progress, value) } +/* + * If true, prints progress information. + * + * call-seq: + * print_progress -> bool + */ static VALUE ruby_whisper_params_get_print_progress(VALUE self) { BOOL_PARAMS_GETTER(self, print_progress) } +/* + * call-seq: + * print_realtime = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_realtime, value) } +/* + * If true, prints results from within whisper.cpp. (avoid it, use callback instead) + * call-seq: + * print_realtime -> bool + */ static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { BOOL_PARAMS_GETTER(self, print_realtime) } +/* + * call-seq: + * print_timestamps = force_print -> force_print + */ static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, print_timestamps, value) } +/* + * If true, prints timestamps for each text segment when printing realtime. 
+ * + * call-seq: + * print_timestamps -> bool + */ static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, print_timestamps) } +/* + * call-seq: + * suppress_blank = force_suppress -> force_suppress + */ static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, suppress_blank, value) } +/* + * If true, suppresses blank outputs. + * + * call-seq: + * suppress_blank -> bool + */ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_blank) } +/* + * call-seq: + * suppress_non_speech_tokens = force_suppress -> force_suppress + */ static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) } +/* + * If true, suppresses non-speech tokens. + * + * call-seq: + * suppress_non_speech_tokens -> bool + */ static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) } +/* + * If true, enables token-level timestamps. + * + * call-seq: + * token_timestamps -> bool + */ static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { BOOL_PARAMS_GETTER(self, token_timestamps) } +/* + * call-seq: + * token_timestamps = force_timestamps -> force_timestamps + */ static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, token_timestamps, value) } +/* + * If true, split on word rather than on token (when used with max_len). + * + * call-seq: + * split_on_word -> bool + */ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { BOOL_PARAMS_GETTER(self, split_on_word) } +/* + * call-seq: + * split_on_word = force_split -> force_split + */ static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, split_on_word, value) } +/* + * If true, enables diarization. 
+ * + * call-seq: + * diarize -> bool + */ static VALUE ruby_whisper_params_get_diarize(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -454,6 +646,10 @@ static VALUE ruby_whisper_params_get_diarize(VALUE self) { return Qfalse; } } +/* + * call-seq: + * diarize = force_diarize -> force_diarize + */ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -465,22 +661,42 @@ static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { return value; } +/* + * Start offset in ms. + * + * call-seq: + * offset -> Integer + */ static VALUE ruby_whisper_params_get_offset(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.offset_ms); } +/* + * call-seq: + * offset = offset_ms -> offset_ms + */ static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); rwp->params.offset_ms = NUM2INT(value); return value; } +/* + * Audio duration to process in ms. + * + * call-seq: + * duration -> Integer + */ static VALUE ruby_whisper_params_get_duration(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.duration_ms); } +/* + * call-seq: + * duration = duration_ms -> duration_ms + */ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -488,23 +704,49 @@ static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { return value; } +/* + * Max tokens to use from past text as prompt for the decoder. 
+ * + * call-seq: + * max_text_tokens -> Integer + */ static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); return INT2NUM(rwp->params.n_max_text_ctx); } +/* + * call-seq: + * max_text_tokens = n_tokens -> n_tokens + */ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +/* + * Sets the new segment callback, called for every newly generated text segment. + * + * params.new_segment_callback = ->(context, _, n_new, user_data) { + * # ... + * } + * + * call-seq: + * new_segment_callback = callback -> callback + */ static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); rwp->new_segment_callback_container->callback = value; return value; } +/* + * Sets the user data passed as the last argument to the new segment callback. + * + * call-seq: + * new_segment_callback_user_data = user_data -> user_data + */ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { ruby_whisper_params *rwp; Data_Get_Struct(self, ruby_whisper_params, rwp); @@ -540,6 +782,24 @@ static VALUE rb_whisper_segment_initialize(VALUE context, int index) { return segment; }; +/* + * Yields each Whisper::Segment: + * + * whisper.transcribe("path/to/audio.wav", params) + * whisper.each_segment do |segment| + * puts segment.text + * end + * + * Returns an Enumerator if no block given: + * + * whisper.transcribe("path/to/audio.wav", params) + * enum = whisper.each_segment + * enum.to_a # => [#<Whisper::Segment>, ...] + * + * call-seq: + * each_segment {|segment| ... 
} + * each_segment -> Enumerator + */ static VALUE ruby_whisper_each_segment(VALUE self) { if (!rb_block_given_p()) { const VALUE method_name = rb_funcall(self, id___method__, 0); @@ -557,6 +817,16 @@ static VALUE ruby_whisper_each_segment(VALUE self) { return self; } +/* + * Hook called on new segment. Yields each Whisper::Segment. + * + * params.on_new_segment do |segment| + * # ... + * end + * + * call-seq: + * on_new_segment {|segment| ... } + */ static VALUE ruby_whisper_params_on_new_segment(VALUE self) { ruby_whisper_params *rws; Data_Get_Struct(self, ruby_whisper_params, rws); @@ -565,6 +835,12 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) { return Qnil; } +/* + * Start time in milliseconds. + * + * call-seq: + * start_time -> Integer + */ static VALUE ruby_whisper_segment_get_start_time(VALUE self) { ruby_whisper_segment *rws; Data_Get_Struct(self, ruby_whisper_segment, rws); @@ -575,6 +851,12 @@ static VALUE ruby_whisper_segment_get_start_time(VALUE self) { return INT2NUM(t0 * 10); } +/* + * End time in milliseconds. + * + * call-seq: + * end_time -> Integer + */ static VALUE ruby_whisper_segment_get_end_time(VALUE self) { ruby_whisper_segment *rws; Data_Get_Struct(self, ruby_whisper_segment, rws); @@ -585,6 +867,12 @@ static VALUE ruby_whisper_segment_get_end_time(VALUE self) { return INT2NUM(t1 * 10); } +/* + * Whether the next segment is predicted as a speaker turn. + * + * call-seq: + * speaker_turn_next? -> bool + */ static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { ruby_whisper_segment *rws; Data_Get_Struct(self, ruby_whisper_segment, rws); @@ -593,6 +881,10 @@ static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? 
Qtrue : Qfalse; } +/* + * call-seq: + * text -> String + */ static VALUE ruby_whisper_segment_get_text(VALUE self) { ruby_whisper_segment *rws; Data_Get_Struct(self, ruby_whisper_segment, rws); From c94c8d4c20b148ee7980fb845a8ab35af10a7afb Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Thu, 24 Oct 2024 17:39:12 +0900 Subject: [PATCH 43/45] Add test for calling Whisper::Params#on_new_segment multiple times --- bindings/ruby/tests/test_segment.rb | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb index debc934a860..f3ebc0e9c78 100644 --- a/bindings/ruby/tests/test_segment.rb +++ b/bindings/ruby/tests/test_segment.rb @@ -65,6 +65,20 @@ def test_on_new_segment assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text end + def test_on_new_segment_twice + params = Whisper::Params.new + seg = nil + params.on_new_segment do |segment| + seg = segment + return + end + params.on_new_segment do |segment| + assert_same seg, segment + return + end + whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params) + end + private def whisper From c5564fc19bb9fd8222519c9b1e0e3fe4374e1733 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 28 Oct 2024 21:08:45 +0900 Subject: [PATCH 44/45] Add file dependencies to GitHub actions config and .gitignore --- .github/workflows/bindings-ruby.yml | 10 ++++++++++ bindings/ruby/ext/.gitignore | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 2b9b57bf29d..d1d3c341bd2 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -16,6 +16,9 @@ on: - ggml/src/ggml-quants.h - ggml/src/ggml-quants.c - ggml/src/ggml-cpu-impl.h + - ggml/src/ggml-metal.m + - ggml/src/ggml-metal.metal + - ggml/src/ggml-blas.cpp - ggml/include/ggml.h - ggml/include/ggml-alloc.h - 
ggml/include/ggml-backend.h @@ -24,6 +27,8 @@ on: - ggml/include/ggml-metal.h - ggml/include/ggml-sycl.h - ggml/include/ggml-vulkan.h + - ggml/include/ggml-blas.h + - scripts/get-flags.mk - examples/dr_wav.h pull_request: paths: @@ -41,6 +46,9 @@ on: - ggml/src/ggml-quants.h - ggml/src/ggml-quants.c - ggml/src/ggml-cpu-impl.h + - ggml/src/ggml-metal.m + - ggml/src/ggml-metal.metal + - ggml/src/ggml-blas.cpp - ggml/include/ggml.h - ggml/include/ggml-alloc.h - ggml/include/ggml-backend.h @@ -49,6 +57,8 @@ on: - ggml/include/ggml-metal.h - ggml/include/ggml-sycl.h - ggml/include/ggml-vulkan.h + - ggml/include/ggml-blas.h + - scripts/get-flags.mk - examples/dr_wav.h jobs: diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 3e99686670c..96779cc0885 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -11,6 +11,9 @@ ggml-backend.c ggml-backend.h ggml-common.h ggml-cpu-impl.h +ggml-metal.m +ggml-metal.metal +ggml-blas.cpp ggml-cuda.h ggml-impl.h ggml-kompute.h @@ -20,6 +23,8 @@ ggml-quants.c ggml-quants.h ggml-sycl.h ggml-vulkan.h +ggml-blas.h +get-flags.mk whisper.cpp whisper.h dr_wav.h From 7a6640a50b536d69fd4ef4991ac4af22a8eb90c8 Mon Sep 17 00:00:00 2001 From: Kitaiti Makoto Date: Mon, 28 Oct 2024 22:04:13 +0900 Subject: [PATCH 45/45] Add more files to ext/.gitignore --- bindings/ruby/ext/.gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index 96779cc0885..c9f31967840 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -13,6 +13,7 @@ ggml-common.h ggml-cpu-impl.h ggml-metal.m ggml-metal.metal +ggml-metal-embed.metal ggml-blas.cpp ggml-cuda.h ggml-impl.h @@ -28,6 +29,7 @@ get-flags.mk whisper.cpp whisper.h dr_wav.h +depend whisper.bundle whisper.so whisper.dll