diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 8492e4ed91b..d11856d06e5 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -87,8 +87,9 @@ whisper = Whisper::Context.new("path/to/your/model.bin") Or, you can download model files: ```ruby -model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin") -whisper = Whisper::Context.new(model_uri) +whisper = Whisper::Context.new("https://example.net/uri/of/your/model.bin") +# Or +whisper = Whisper::Context.new(URI("https://example.net/uri/of/your/model.bin")) ``` See [models][] page for details. diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 88a4fd2c205..62ea04cbbd2 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -49,6 +49,7 @@ static ID id_length; static ID id_next; static ID id_new; static ID id_to_path; +static ID id_URI; static ID id_pre_converted_models; static bool is_log_callback_finalized = false; @@ -115,7 +116,7 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { VALUE old_callback = rb_iv_get(self, "log_callback"); if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); + rb_undefine_finalizer(self); } rb_iv_set(self, "log_callback", log_callback); @@ -283,6 +284,17 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { if (!NIL_P(pre_converted_model)) { whisper_model_file_path = pre_converted_model; } + if (TYPE(whisper_model_file_path) == T_STRING) { + const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path); + if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } + } + if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } if (rb_respond_to(whisper_model_file_path, id_to_path)) { whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); } @@ -1802,6 +1814,7 @@ void Init_whisper() { id_next = rb_intern("next"); id_new = rb_intern("new"); id_to_path = rb_intern("to_path"); + id_URI = rb_intern("URI"); id_pre_converted_models = rb_intern("pre_converted_models"); mWhisper = rb_define_module("Whisper"); diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb index 1362fc469bf..df871e0e651 100644 --- a/bindings/ruby/tests/test_model.rb +++ b/bindings/ruby/tests/test_model.rb @@ -68,4 +68,42 @@ def test_auto_download assert_path_exist path assert_equal 147964211, File.size(path) end + + def test_uri_string + path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin" + whisper = Whisper::Context.new(path) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end + + def test_uri + path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin") + whisper = Whisper::Context.new(path) + model = whisper.model + + assert_equal 51864, model.n_vocab + assert_equal 1500, model.n_audio_ctx + assert_equal 512, model.n_audio_state + assert_equal 8, model.n_audio_head + assert_equal 6, model.n_audio_layer + assert_equal 448, model.n_text_ctx + assert_equal 512, model.n_text_state + assert_equal 8, model.n_text_head + assert_equal 6, model.n_text_layer + assert_equal 80, model.n_mels + assert_equal 1, model.ftype + assert_equal "base", model.type + end end