diff --git a/bindings/ruby/.gitignore b/bindings/ruby/.gitignore index 6e3b3be0e24..e04a90a9c69 100644 --- a/bindings/ruby/.gitignore +++ b/bindings/ruby/.gitignore @@ -1,5 +1,3 @@ LICENSE pkg/ -lib/whisper.so -lib/whisper.bundle -lib/whisper.dll +lib/whisper.* diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 8492e4ed91b..13ff1f00ad1 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -60,10 +60,10 @@ You also can use shorthand for pre-converted models: whisper = Whisper::Context.new("base.en") ``` -You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`: +You can see the list of prepared model names by `Whisper::Model.pre_converted_models.keys`: ```ruby -puts Whisper::Model.preconverted_models.keys +puts Whisper::Model.pre_converted_models.keys # tiny # tiny.en # tiny-q5_1 @@ -87,8 +87,9 @@ whisper = Whisper::Context.new("path/to/your/model.bin") Or, you can download model files: ```ruby -model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin") -whisper = Whisper::Context.new(model_uri) +whisper = Whisper::Context.new("https://example.net/uri/of/your/model.bin") +# Or +whisper = Whisper::Context.new(URI("https://example.net/uri/of/your/model.bin")) ``` See [models][] page for details. @@ -222,6 +223,17 @@ end The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +Development +----------- + + % git clone https://github.com/ggerganov/whisper.cpp.git + % cd whisper.cpp/bindings/ruby + % rake test + +The first run of `rake test` builds the extension and downloads a model for testing. After that, you can add tests in the `tests` directory and modify `ext/ruby_whisper.cpp`. + +If the build seems broken, running `rake clean` may fix it.
+ License ------- diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 88a4fd2c205..5979f208ec9 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -49,6 +49,7 @@ static ID id_length; static ID id_next; static ID id_new; static ID id_to_path; +static ID id_URI; static ID id_pre_converted_models; static bool is_log_callback_finalized = false; @@ -283,6 +284,17 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { if (!NIL_P(pre_converted_model)) { whisper_model_file_path = pre_converted_model; } + if (TYPE(whisper_model_file_path) == T_STRING) { + const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path); + if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } + } + if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } if (rb_respond_to(whisper_model_file_path, id_to_path)) { whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); } @@ -837,7 +849,7 @@ static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { /* * call-seq: - * full_get_segment_no_speech_prob -> Float + * full_get_segment_no_speech_prob(segment_index) -> Float */ static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) { ruby_whisper *rw; @@ -1755,7 +1767,7 @@ static VALUE ruby_whisper_c_model_type(VALUE self) { static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) { const int c_code = NUM2INT(code); - char *raw_message; + const char *raw_message; switch (c_code) { case -2: raw_message = "failed to compute log mel 
spectrogram"; @@ -1802,6 +1814,7 @@ void Init_whisper() { id_next = rb_intern("next"); id_new = rb_intern("new"); id_to_path = rb_intern("to_path"); + id_URI = rb_intern("URI"); id_pre_converted_models = rb_intern("pre_converted_models"); mWhisper = rb_define_module("Whisper"); @@ -1941,6 +1954,8 @@ void Init_whisper() { rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0); rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0); rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0); + + rb_require("whisper/model/uri"); } #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index a6771038e6f..21e36c491cf 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -1,5 +1,5 @@ -#ifndef __RUBY_WHISPER_H -#define __RUBY_WHISPER_H +#ifndef RUBY_WHISPER_H +#define RUBY_WHISPER_H #include "whisper.h" diff --git a/bindings/ruby/lib/whisper.rb b/bindings/ruby/lib/whisper.rb deleted file mode 100644 index 3a0b844e15c..00000000000 --- a/bindings/ruby/lib/whisper.rb +++ /dev/null @@ -1,2 +0,0 @@ -require "whisper.so" -require "whisper/model/uri" diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index fe5ed56b3fb..b43d90dd486 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -1,163 +1,163 @@ -require "whisper.so" require "uri" require "net/http" require "time" require "pathname" require "io/console/size" -class Whisper::Model - class URI - def initialize(uri) - @uri = URI(uri) - end +module Whisper + class Model + class URI + def initialize(uri) + @uri = URI(uri) + end - def to_path - cache - cache_path.to_path - end + def to_path + cache + cache_path.to_path + end - def clear_cache - path = cache_path - path.delete if path.exist? - end + def clear_cache + path = cache_path + path.delete if path.exist? 
+ end - private + private - def cache_path - base_cache_dir/@uri.host/@uri.path[1..] - end + def cache_path + base_cache_dir/@uri.host/@uri.path[1..] + end - def base_cache_dir - base = case RUBY_PLATFORM - when /mswin|mingw/ - ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local" - when /darwin/ - Pathname(Dir.home)/"Library/Caches" - else - ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache" - end - base/"whisper.cpp" - end + def base_cache_dir + base = case RUBY_PLATFORM + when /mswin|mingw/ + ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local" + when /darwin/ + Pathname(Dir.home)/"Library/Caches" + else + ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache" + end + base/"whisper.cpp" + end - def cache - path = cache_path - headers = {} - headers["if-modified-since"] = path.mtime.httpdate if path.exist? - request @uri, headers - path - end + def cache + path = cache_path + headers = {} + headers["if-modified-since"] = path.mtime.httpdate if path.exist? 
+ request @uri, headers + path + end - def request(uri, headers) - Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http| - request = Net::HTTP::Get.new(uri, headers) - http.request request do |response| - case response - when Net::HTTPNotModified + def request(uri, headers) + Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http| + request = Net::HTTP::Get.new(uri, headers) + http.request request do |response| + case response + when Net::HTTPNotModified # noop - when Net::HTTPOK - download response - when Net::HTTPRedirection - request URI(response["location"]), headers - else - return if headers.key?("if-modified-since") # Use cache file - - raise "#{response.code} #{response.message}\n#{response.body}" + when Net::HTTPOK + download response + when Net::HTTPRedirection + request URI(response["location"]), headers + else + return if headers.key?("if-modified-since") # Use cache file + + raise "#{response.code} #{response.message}\n#{response.body}" + end end end end - end - def download(response) - path = cache_path - path.dirname.mkpath unless path.dirname.exist? - downloading_path = Pathname("#{path}.downloading") - size = response.content_length - downloading_path.open "wb" do |file| - downloaded = 0 - response.read_body do |chunk| - file << chunk - downloaded += chunk.bytesize - show_progress downloaded, size + def download(response) + path = cache_path + path.dirname.mkpath unless path.dirname.exist? + downloading_path = Pathname("#{path}.downloading") + size = response.content_length + downloading_path.open "wb" do |file| + downloaded = 0 + response.read_body do |chunk| + file << chunk + downloaded += chunk.bytesize + show_progress downloaded, size + end + $stderr.puts end - $stderr.puts + downloading_path.rename path end - downloading_path.rename path - end - def show_progress(current, size) - progress_rate_available = size && $stderr.tty? 
+ def show_progress(current, size) + progress_rate_available = size && $stderr.tty? - unless @prev - @prev = Time.now - $stderr.puts "Downloading #{@uri} to #{cache_path}" - end + unless @prev + @prev = Time.now + $stderr.puts "Downloading #{@uri} to #{cache_path}" + end - now = Time.now + now = Time.now - if progress_rate_available - return if now - @prev < 1 && current < size + if progress_rate_available + return if now - @prev < 1 && current < size - progress_width = 20 - progress = current.to_f / size - arrow_length = progress * progress_width - arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length) - line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})" - padding = ' ' * ($stderr.winsize[1] - line.size) - $stderr.print "\r#{line}#{padding}" - else - return if now - @prev < 1 + progress_width = 20 + progress = current.to_f / size + arrow_length = progress * progress_width + arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length) + line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})" + padding = ' ' * ($stderr.winsize[1] - line.size) + $stderr.print "\r#{line}#{padding}" + else + return if now - @prev < 1 - $stderr.print "." + $stderr.print "." + end + @prev = now end - @prev = now - end - def format_bytesize(bytesize) - return "0.0 B" if bytesize.zero? + def format_bytesize(bytesize) + return "0.0 B" if bytesize.zero? 
- units = %w[B KiB MiB GiB TiB] - exp = (Math.log(bytesize) / Math.log(1024)).to_i - format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp]) + units = %w[B KiB MiB GiB TiB] + exp = (Math.log(bytesize) / Math.log(1024)).to_i + format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp]) + end end - end - @pre_converted_models = {} - %w[ - tiny - tiny.en - tiny-q5_1 - tiny.en-q5_1 - tiny-q8_0 - base - base.en - base-q5_1 - base.en-q5_1 - base-q8_0 - small - small.en - small.en-tdrz - small-q5_1 - small.en-q5_1 - small-q8_0 - medium - medium.en - medium-q5_0 - medium.en-q5_0 - medium-q8_0 - large-v1 - large-v2 - large-v2-q5_0 - large-v2-q8_0 - large-v3 - large-v3-q5_0 - large-v3-turbo - large-v3-turbo-q5_0 - large-v3-turbo-q8_0 - ].each do |name| - @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") - end - - class << self - attr_reader :pre_converted_models + @pre_converted_models = %w[ + tiny + tiny.en + tiny-q5_1 + tiny.en-q5_1 + tiny-q8_0 + base + base.en + base-q5_1 + base.en-q5_1 + base-q8_0 + small + small.en + small.en-tdrz + small-q5_1 + small.en-q5_1 + small-q8_0 + medium + medium.en + medium-q5_0 + medium.en-q5_0 + medium-q8_0 + large-v1 + large-v2 + large-v2-q5_0 + large-v2-q8_0 + large-v3 + large-v3-q5_0 + large-v3-turbo + large-v3-turbo-q5_0 + large-v3-turbo-q8_0 + ].each_with_object({}) {|name, models| + models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") + } + + class << self + attr_reader :pre_converted_models + end end end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs new file mode 100644 index 00000000000..aff2ae73ee8 --- /dev/null +++ b/bindings/ruby/sig/whisper.rbs @@ -0,0 +1,153 @@ +module Whisper + interface _Samples + def length: () -> Integer + def each: { (Float) -> void } -> void + end + + type log_callback = ^(Integer level, String message, Object user_data) -> void + type new_segment_callback = 
^(Whisper::Context, void, Integer n_new, Object user_data) -> void + type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void + type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish + + LOG_LEVEL_NONE: Integer + LOG_LEVEL_INFO: Integer + LOG_LEVEL_WARN: Integer + LOG_LEVEL_ERROR: Integer + LOG_LEVEL_DEBUG: Integer + LOG_LEVEL_CONT: Integer + + def self.lang_max_id: () -> Integer + def self.lang_id: (string name) -> Integer + def self.lang_str: (Integer id) -> String + def self.lang_str_full: (Integer id) -> String + def self.log_set=: (log_callback) -> log_callback + def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer + + class Context + def initialize: (string | _ToPath | ::URI::HTTP ) -> void + def transcribe: (string, Params) -> void + | (string, Params) { (String) -> void } -> void + def model_n_vocab: () -> Integer + def model_n_audio_ctx: () -> Integer + def model_n_audio_state: () -> Integer + def model_n_text_head: () -> Integer + def model_n_text_layer: () -> Integer + def model_n_mels: () -> Integer + def model_ftype: () -> Integer + def model_type: () -> String + def full_n_segments: () -> Integer + def full_lang_id: () -> Integer + def full_get_segment_t0: (Integer) -> Integer + def full_get_segment_t1: (Integer) -> Integer + def full_get_segment_speaker_turn_next: (Integer) -> (true | false) + def full_get_segment_text: (Integer) -> String + def full_get_segment_no_speech_prob: (Integer) -> Float + def full: (Params, Array[Float], ?Integer) -> void + | (Params, _Samples, ?Integer) -> void + def full_parallel: (Params, Array[Float], ?Integer) -> void + | (Params, _Samples, ?Integer) -> void + | (Params, _Samples, ?Integer?, Integer) -> void + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + def model: () -> Model + end + + class Params + def initialize: () -> void + def language=: (String) -> String # TODO: Enumerate 
lang names + def language: () -> String + def translate=: (boolish) -> boolish + def translate: () -> (true | false) + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + def single_segment=: (boolish) -> boolish + def single_segment: () -> (true | false) + def print_special=: (boolish) -> boolish + def print_special: () -> (true | false) + def print_progress=: (boolish) -> boolish + def print_progress: () -> (true | false) + def print_realtime=: (boolish) -> boolish + def print_realtime: () -> (true | false) + def print_timestamps=: (boolish) -> boolish + def print_timestamps: () -> (true | false) + def suppress_blank=: (boolish) -> boolish + def suppress_blank: () -> (true | false) + def suppress_nst=: (boolish) -> boolish + def suppress_nst: () -> (true | false) + def token_timestamps=: (boolish) -> boolish + def token_timestamps: () -> (true | false) + def split_on_word=: (boolish) -> boolish + def split_on_word: () -> (true | false) + def initial_prompt=: (_ToS) -> _ToS + def initial_prompt: () -> String + def diarize=: (boolish) -> boolish + def diarize: () -> (true | false) + def offset=: (Integer) -> Integer + def offset: () -> Integer + def duration=: (Integer) -> Integer + def duration: () -> Integer + def max_text_tokens=: (Integer) -> Integer + def max_text_tokens: () -> Integer + def temperature=: (Float) -> Float + def temperature: () -> Float + def max_initial_ts=: (Float) -> Float + def max_initial_ts: () -> Float + def length_penalty=: (Float) -> Float + def length_penalty: () -> Float + def temperature_inc=: (Float) -> Float + def temperature_inc: () -> Float + def entropy_thold=: (Float) -> Float + def entropy_thold: () -> Float + def logprob_thold=: (Float) -> Float + def logprob_thold: () -> Float + def no_speech_thold=: (Float) -> Float + def no_speech_thold: () -> Float + def new_segment_callback=: (new_segment_callback) -> new_segment_callback + def new_segment_callback_user_data=: (Object) -> Object + def 
progress_callback=: (progress_callback) -> progress_callback + def progress_callback_user_data=: (Object) -> Object + def abort_callback=: (abort_callback) -> abort_callback + def abort_callback_user_data=: (Object) -> Object + def on_new_segment: { (Segment) -> void } -> void + def on_progress: { (Integer) -> void } -> void + def abort_on: { (Object) -> boolish } -> void + end + + class Model + def self.pre_converted_models: () -> Hash[String, Model::URI] + def initialize: () -> void + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_text_ctx: () -> Integer + def n_text_state: () -> Integer + def n_text_head: () -> Integer + def n_text_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + def type: () -> String + + class URI + def initialize: (string | ::URI::HTTP) -> void + def to_path: -> String + def clear_cache: -> void + end + end + + class Segment + def initialize: () -> void + def start_time: () -> Integer + def end_time: () -> Integer + def speaker_next_turn?: () -> (true | false) + def text: () -> String + def no_speech_prob: () -> Float + end + + class Error < StandardError + attr_reader code: Integer + + def initialize: (Integer) -> void + end +end diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb index 1362fc469bf..df871e0e651 100644 --- a/bindings/ruby/tests/test_model.rb +++ b/bindings/ruby/tests/test_model.rb @@ -68,4 +68,42 @@ def test_auto_download assert_path_exist path assert_equal 147964211, File.size(path) end + + def test_uri_string + path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin" + whisper = Whisper::Context.new(path) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, 
model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end + + def test_uri + path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin") + whisper = Whisper::Context.new(path) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end end