Skip to content

Commit 30f3c9d

Browse files
committed
Update type signatures
1 parent 744b64e commit 30f3c9d

File tree

2 files changed

+89
-50
lines changed

2 files changed

+89
-50
lines changed

bindings/ruby/sig/whisper.rbs

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ module Whisper
2424
def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
2525

2626
class Context
27-
def initialize: (string | _ToPath | ::URI::HTTP ) -> void
27+
def initialize: (string | _ToPath | ::URI::HTTP) -> void
2828
def transcribe: (string, Params) -> void
2929
| (string, Params) { (String) -> void } -> void
3030
def model_n_vocab: () -> Integer
@@ -42,18 +42,49 @@ module Whisper
4242
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
4343
def full_get_segment_text: (Integer) -> String
4444
def full_get_segment_no_speech_prob: (Integer) -> Float
45-
def full: (Params, Array[Float], ?Integer) -> void
46-
| (Params, _Samples, ?Integer) -> void
47-
def full_parallel: (Params, Array[Float], ?Integer) -> void
48-
| (Params, _Samples, ?Integer) -> void
49-
| (Params, _Samples, ?Integer?, Integer) -> void
45+
def full: (Params, Array[Float] samples, ?Integer n_samples) -> void
46+
| (Params, _Samples, ?Integer n_samples) -> void
47+
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> void
48+
| (Params, _Samples, ?Integer n_samples) -> void
49+
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> void
5050
def each_segment: { (Segment) -> void } -> void
5151
| () -> Enumerator[Segment]
5252
def model: () -> Model
5353
end
5454

5555
class Params
56-
def initialize: () -> void
56+
def initialize: (
57+
?language: string,
58+
?translate: boolish,
59+
?no_context: boolish,
60+
?single_segment: boolish,
61+
?print_special: boolish,
62+
?print_progress: boolish,
63+
?print_realtime: boolish,
64+
?print_timestamps: boolish,
65+
?suppress_blank: boolish,
66+
?suppress_nst: boolish,
67+
?token_timestamps: boolish,
68+
?split_on_word: boolish,
69+
?initial_prompt: string | nil,
70+
?diarize: boolish,
71+
?offset: Integer,
72+
?duration: Integer,
73+
?max_text_tokens: Integer,
74+
?temperature: Float,
75+
?max_initial_ts: Float,
76+
?length_penalty: Float,
77+
?temperature_inc: Float,
78+
?entropy_thold: Float,
79+
?logprob_thold: Float,
80+
?no_speech_thold: Float,
81+
?new_segment_callback: new_segment_callback,
82+
?new_segment_callback_user_data: Object,
83+
?progress_callback: progress_callback,
84+
?progress_callback_user_data: Object,
85+
?abort_callback: abort_callback,
86+
?abort_callback_user_data: Object
87+
) -> void
5788
def language=: (String) -> String # TODO: Enumerate lang names
5889
def language: () -> String
5990
def translate=: (boolish) -> boolish
@@ -79,7 +110,7 @@ module Whisper
79110
def split_on_word=: (boolish) -> boolish
80111
def split_on_word: () -> (true | false)
81112
def initial_prompt=: (_ToS) -> _ToS
82-
def initial_prompt: () -> String
113+
def initial_prompt: () -> (String | nil)
83114
def diarize=: (boolish) -> boolish
84115
def diarize: () -> (true | false)
85116
def offset=: (Integer) -> Integer
@@ -103,14 +134,20 @@ module Whisper
103134
def no_speech_thold=: (Float) -> Float
104135
def no_speech_thold: () -> Float
105136
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
137+
def new_segment_callback: () -> (new_segment_callback | nil)
106138
def new_segment_callback_user_data=: (Object) -> Object
139+
def new_segment_callback_user_data: () -> Object
107140
def progress_callback=: (progress_callback) -> progress_callback
141+
def progress_callback: () -> (progress_callback | nil)
108142
def progress_callback_user_data=: (Object) -> Object
143+
def progress_callback_user_data: () -> Object
109144
def abort_callback=: (abort_callback) -> abort_callback
145+
def abort_callback: () -> (abort_callback | nil)
110146
def abort_callback_user_data=: (Object) -> Object
147+
def abort_callback_user_data: () -> Object
111148
def on_new_segment: { (Segment) -> void } -> void
112-
def on_progress: { (Integer) -> void } -> void
113-
def abort_on: { (Object) -> boolish } -> void
149+
def on_progress: { (Integer progress) -> void } -> void
150+
def abort_on: { (Object user_data) -> boolish } -> void
114151
end
115152

116153
class Model
@@ -148,6 +185,6 @@ module Whisper
148185
class Error < StandardError
149186
attr_reader code: Integer
150187

151-
def initialize: (Integer) -> void
188+
def initialize: (Integer code) -> void
152189
end
153190
end

bindings/ruby/tests/test_params.rb

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
11
require_relative "helper"
22

33
class TestParams < TestBase
4-
DEFAULT_VALUES = {
5-
language: "en",
6-
translate: false,
7-
no_context: true,
8-
single_segment: false,
9-
print_special: false,
10-
print_progress: true,
11-
print_realtime: false,
12-
print_timestamps: true,
13-
suppress_blank: true,
14-
suppress_nst: false,
15-
token_timestamps: false,
16-
split_on_word: false,
17-
initial_prompt: nil,
18-
diarize: false,
19-
offset: 0,
20-
duration: 0,
21-
max_text_tokens: 16384,
22-
temperature: 0.0,
23-
max_initial_ts: 1.0,
24-
length_penalty: -1.0,
25-
temperature_inc: 0.2,
26-
entropy_thold: 2.4,
27-
logprob_thold: -1.0,
28-
no_speech_thold: 0.6,
29-
new_segment_callback: nil,
30-
new_segment_callback_user_data: nil,
31-
progress_callback: nil,
32-
progress_callback_user_data: nil,
33-
abort_callback: nil,
34-
abort_callback_user_data: nil
35-
}
4+
PARAM_NAMES = [
5+
:language,
6+
:translate,
7+
:no_context,
8+
:single_segment,
9+
:print_special,
10+
:print_progress,
11+
:print_realtime,
12+
:print_timestamps,
13+
:suppress_blank,
14+
:suppress_nst,
15+
:token_timestamps,
16+
:split_on_word,
17+
:initial_prompt,
18+
:diarize,
19+
:offset,
20+
:duration,
21+
:max_text_tokens,
22+
:temperature,
23+
:max_initial_ts,
24+
:length_penalty,
25+
:temperature_inc,
26+
:entropy_thold,
27+
:logprob_thold,
28+
:no_speech_thold,
29+
:new_segment_callback,
30+
:new_segment_callback_user_data,
31+
:progress_callback,
32+
:progress_callback_user_data,
33+
:abort_callback,
34+
:abort_callback_user_data,
35+
]
3636

3737
def setup
3838
@params = Whisper::Params.new
@@ -209,9 +209,9 @@ def test_new_with_kw_args_wrong_type
209209
end
210210
end
211211

212-
data(DEFAULT_VALUES.collect {|param, value| [param, [param, value]]}.to_h)
213-
def test_new_with_kw_args_default_values(data)
214-
param, default_value = data
212+
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
213+
def test_new_with_kw_args_default_values(param)
214+
default_value = @params.send(param)
215215
value = case [param, default_value]
216216
in [*, true | false]
217217
!default_value
@@ -233,11 +233,13 @@ def test_new_with_kw_args_default_values(data)
233233
assert_equal value, params.send(param)
234234
end
235235

236-
DEFAULT_VALUES.except(param).each do |key, default_value|
237-
if Float === default_value
238-
assert_in_delta default_value, params.send(key)
236+
PARAM_NAMES.reject {|name| name == param}.each do |name|
237+
expected = @params.send(name)
238+
actual = params.send(name)
239+
if Float === expected
240+
assert_in_delta expected, actual
239241
else
240-
assert_equal default_value, params.send(key)
242+
assert_equal expected, actual
241243
end
242244
end
243245
end

0 commit comments

Comments
 (0)