Skip to content

Commit 799eacd

Browse files
ruby : Add parallel transcription support (#3222)
* Fix indentation of code sample in document comment * Make Whisper::Context#transcribe able to run non-parallel * Add test for Whisper::Context#transcribe with parallel option * Follow signature API change of Context#transcribe * Remove useless variable assignment * Move simple usage up in README * Add need help section in README * Add document on Context#transcribe's parallel option in README * Update date * Fix signature of Context.new * Make Context#subscribe accept n_processors option * Make test follow #transcribe's change * Make RBS follow #transcribe's change * Add document for #transcribe's n_processors option * Rename test directory so that Rake tasks' default setting is used
1 parent 82f461e commit 799eacd

20 files changed

+107
-62
lines changed

bindings/ruby/README.md

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,6 @@ end
7070

7171
Some models are prepared up-front:
7272

73-
```ruby
74-
base_en = Whisper::Model.pre_converted_models["base.en"]
75-
whisper = Whisper::Context.new(base_en)
76-
```
77-
78-
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
79-
80-
```ruby
81-
Whisper::Model.pre_converted_models["base"].clear_cache
82-
```
83-
8473
You also can use shorthand for pre-converted models:
8574

8675
```ruby
@@ -105,6 +94,19 @@ puts Whisper::Model.pre_converted_models.keys
10594
# :
10695
```
10796

97+
You can also retrieve each model:
98+
99+
```ruby
100+
base_en = Whisper::Model.pre_converted_models["base.en"]
101+
whisper = Whisper::Context.new(base_en)
102+
```
103+
104+
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
105+
106+
```ruby
107+
Whisper::Model.pre_converted_models["base"].clear_cache
108+
```
109+
108110
You can also use local model files you prepared:
109111

110112
```ruby
@@ -163,6 +165,16 @@ For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisp
163165
API
164166
---
165167

168+
### Transcription ###
169+
170+
By default, `Whisper::Context#transcribe` works in a single thread. You can make it work in parallel by passing `n_processors` option:
171+
172+
```ruby
173+
whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
174+
```
175+
176+
Note that transcription occasionally might be low accuracy when it works in parallel.
177+
166178
### Segments ###
167179

168180
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
@@ -297,6 +309,11 @@ First call of `rake test` builds an extension and downloads a model for testing.
297309

298310
If something seems wrong on build, running `rake clean` solves some cases.
299311

312+
### Need help ###
313+
314+
* Windows support
315+
* Refinement of C/C++ code, especially memory management
316+
300317
License
301318
-------
302319

bindings/ruby/Rakefile

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,15 @@ file LIB_FILE => [SO_FILE, "lib"] do |t|
6767
end
6868
CLEAN.include LIB_FILE
6969

70-
Rake::TestTask.new do |t|
71-
t.test_files = FileList["tests/test_*.rb"]
72-
end
70+
Rake::TestTask.new
7371

74-
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
75-
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
76-
chdir "tests/jfk_reader" do
72+
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
73+
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
74+
chdir "test/jfk_reader" do
7775
ruby "extconf.rb"
7876
sh "make"
7977
end
8078
end
81-
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
79+
CLEAN.include "test/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
8280

8381
task test: [LIB_FILE, TEST_MEMORY_VIEW]

bindings/ruby/ext/ruby_whisper.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ID id_URI;
2424
ID id_pre_converted_models;
2525
ID id_coreml_compiled_models;
2626
ID id_cache;
27+
ID id_n_processors;
2728

2829
static bool is_log_callback_finalized = false;
2930

@@ -142,6 +143,7 @@ void Init_whisper() {
142143
id_pre_converted_models = rb_intern("pre_converted_models");
143144
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
144145
id_cache = rb_intern("cache");
146+
id_n_processors = rb_intern("n_processors");
145147

146148
mWhisper = rb_define_module("Whisper");
147149
mVAD = rb_define_module_under(mWhisper, "VAD");

bindings/ruby/ext/ruby_whisper_context.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ extern ID id_URI;
1313
extern ID id_pre_converted_models;
1414
extern ID id_coreml_compiled_models;
1515
extern ID id_cache;
16+
extern ID id_n_processors;
1617

1718
extern VALUE cContext;
1819
extern VALUE eError;
@@ -24,6 +25,8 @@ extern VALUE rb_whisper_model_s_new(VALUE context);
2425
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
2526
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
2627

28+
ID transcribe_option_names[1];
29+
2730
static void
2831
ruby_whisper_free(ruby_whisper *rw)
2932
{
@@ -633,6 +636,8 @@ init_ruby_whisper_context(VALUE *mWhisper)
633636
{
634637
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
635638

639+
transcribe_option_names[0] = id_n_processors;
640+
636641
rb_define_alloc_func(cContext, ruby_whisper_allocate);
637642
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
638643

bindings/ruby/ext/ruby_whisper_transcribe.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ extern const rb_data_type_t ruby_whisper_params_type;
1313

1414
extern ID id_to_s;
1515
extern ID id_call;
16+
extern ID transcribe_option_names[1];
1617

1718
extern void
1819
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
@@ -34,9 +35,14 @@ VALUE
3435
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
3536
ruby_whisper *rw;
3637
ruby_whisper_params *rwp;
37-
VALUE wave_file_path, blk, params;
38+
VALUE wave_file_path, blk, params, kws;
39+
VALUE opts[1];
40+
41+
rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
42+
rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);
43+
44+
int n_processors = opts[0] == Qundef ? 1 : NUM2INT(opts[0]);
3845

39-
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
4046
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
4147
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
4248

@@ -66,7 +72,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
6672

6773
prepare_transcription(rwp, &self);
6874

69-
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
75+
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
7076
fprintf(stderr, "failed to process audio\n");
7177
return self;
7278
}
@@ -76,9 +82,8 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
7682
const char * text = whisper_full_get_segment_text(rw->context, i);
7783
output = rb_str_concat(output, rb_str_new2(text));
7884
}
79-
VALUE idCall = id_call;
8085
if (blk != Qnil) {
81-
rb_funcall(blk, idCall, 1, output);
86+
rb_funcall(blk, id_call, 1, output);
8287
}
8388
return self;
8489
}

bindings/ruby/sig/whisper.rbs

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@ module Whisper
2525
def self.system_info_str: () -> String
2626

2727
class Context
28-
def self.new: (path | ::URI::HTTP) -> instance
28+
def self.new: (String | path | ::URI::HTTP) -> instance
2929

3030
# transcribe a single file
3131
# can emit to a block results
3232
#
33-
# params = Whisper::Params.new
34-
# params.duration = 60_000
35-
# whisper.transcribe "path/to/audio.wav", params do |text|
36-
# puts text
37-
# end
33+
# params = Whisper::Params.new
34+
# params.duration = 60_000
35+
# whisper.transcribe "path/to/audio.wav", params do |text|
36+
# puts text
37+
# end
3838
#
39-
def transcribe: (string, Params) -> self
40-
| (string, Params) { (String) -> void } -> self
39+
def transcribe: (string, Params, ?n_processors: Integer) -> self
40+
| (string, Params, ?n_processors: Integer) { (String) -> void } -> self
4141

4242
def model_n_vocab: () -> Integer
4343
def model_n_audio_ctx: () -> Integer
@@ -50,16 +50,16 @@ module Whisper
5050

5151
# Yields each Whisper::Segment:
5252
#
53-
# whisper.transcribe("path/to/audio.wav", params)
54-
# whisper.each_segment do |segment|
55-
# puts segment.text
56-
# end
53+
# whisper.transcribe("path/to/audio.wav", params)
54+
# whisper.each_segment do |segment|
55+
# puts segment.text
56+
# end
5757
#
5858
# Returns an Enumerator if no block given:
5959
#
60-
# whisper.transcribe("path/to/audio.wav", params)
61-
# enum = whisper.each_segment
62-
# enum.to_a # => [#<Whisper::Segment>, ...]
60+
# whisper.transcribe("path/to/audio.wav", params)
61+
# enum = whisper.each_segment
62+
# enum.to_a # => [#<Whisper::Segment>, ...]
6363
#
6464
def each_segment: { (Segment) -> void } -> void
6565
| () -> Enumerator[Segment]
@@ -74,25 +74,25 @@ module Whisper
7474

7575
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
7676
#
77-
# full_get_segment_t0(3) # => 1668 (16680 ms)
77+
# full_get_segment_t0(3) # => 1668 (16680 ms)
7878
#
7979
def full_get_segment_t0: (Integer) -> Integer
8080

8181
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
8282
#
83-
# full_get_segment_t1(3) # => 1668 (16680 ms)
83+
# full_get_segment_t1(3) # => 1668 (16680 ms)
8484
#
8585
def full_get_segment_t1: (Integer) -> Integer
8686

8787
# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
8888
#
89-
# full_get_segment_speacker_turn_next(3) # => true
89+
# full_get_segment_speacker_turn_next(3) # => true
9090
#
9191
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
9292

9393
# Text of a segment indexed by +segment_index+.
9494
#
95-
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
95+
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
9696
#
9797
def full_get_segment_text: (Integer) -> String
9898

@@ -282,9 +282,9 @@ module Whisper
282282

283283
# Sets new segment callback, called for every newly generated text segment.
284284
#
285-
# params.new_segment_callback = ->(context, _, n_new, user_data) {
286-
# # ...
287-
# }
285+
# params.new_segment_callback = ->(context, _, n_new, user_data) {
286+
# # ...
287+
# }
288288
#
289289
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
290290
def new_segment_callback: () -> (new_segment_callback | nil)
@@ -297,9 +297,9 @@ module Whisper
297297

298298
# Sets progress callback, called on each progress update.
299299
#
300-
# params.new_segment_callback = ->(context, _, progress, user_data) {
301-
# # ...
302-
# }
300+
# params.new_segment_callback = ->(context, _, progress, user_data) {
301+
# # ...
302+
# }
303303
#
304304
# +progress+ is an Integer between 0 and 100.
305305
#
@@ -327,9 +327,9 @@ module Whisper
327327

328328
# Sets abort callback, called to check if the process should be aborted.
329329
#
330-
# params.abort_callback = ->(user_data) {
331-
# # ...
332-
# }
330+
# params.abort_callback = ->(user_data) {
331+
# # ...
332+
# }
333333
#
334334
#
335335
def abort_callback=: (abort_callback) -> abort_callback
@@ -358,9 +358,9 @@ module Whisper
358358

359359
# Hook called on new segment. Yields each Whisper::Segment.
360360
#
361-
# whisper.on_new_segment do |segment|
362-
# # ...
363-
# end
361+
# whisper.on_new_segment do |segment|
362+
# # ...
363+
# end
364364
#
365365
def on_new_segment: { (Segment) -> void } -> void
366366

@@ -374,13 +374,13 @@ module Whisper
374374

375375
# Call block to determine whether abort or not. Return +true+ when you want to abort.
376376
#
377-
# params.abort_on do
378-
# if some_condition
379-
# true # abort
380-
# else
381-
# false # continue
377+
# params.abort_on do
378+
# if some_condition
379+
# true # abort
380+
# else
381+
# false # continue
382+
# end
382383
# end
383-
# end
384384
#
385385
def abort_on: { (Object user_data) -> boolish } -> void
386386
end
File renamed without changes.

0 commit comments

Comments
 (0)