Skip to content

Commit 744b64e

Browse files
committed
Make Whisper::Params.new accept keyword arguments
1 parent 2b2dd5e commit 744b64e

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

bindings/ruby/ext/ruby_whisper_params.c

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,46 @@
2020
return Qfalse; \
2121
}
2222

23+
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
24+
2325
extern VALUE cParams;
2426

2527
extern ID id_call;
2628

2729
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
2830

31+
static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
32+
static ID id_language;
33+
static ID id_translate;
34+
static ID id_no_context;
35+
static ID id_single_segment;
36+
static ID id_print_special;
37+
static ID id_print_progress;
38+
static ID id_print_realtime;
39+
static ID id_print_timestamps;
40+
static ID id_suppress_blank;
41+
static ID id_suppress_nst;
42+
static ID id_token_timestamps;
43+
static ID id_split_on_word;
44+
static ID id_initial_prompt;
45+
static ID id_diarize;
46+
static ID id_offset;
47+
static ID id_duration;
48+
static ID id_max_text_tokens;
49+
static ID id_temperature;
50+
static ID id_max_initial_ts;
51+
static ID id_length_penalty;
52+
static ID id_temperature_inc;
53+
static ID id_entropy_thold;
54+
static ID id_logprob_thold;
55+
static ID id_no_speech_thold;
56+
static ID id_new_segment_callback;
57+
static ID id_new_segment_callback_user_data;
58+
static ID id_progress_callback;
59+
static ID id_progress_callback_user_data;
60+
static ID id_abort_callback;
61+
static ID id_abort_callback_user_data;
62+
2963
static void
3064
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
3165
{
@@ -854,6 +888,78 @@ ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
854888
return value;
855889
}
856890

891+
#define SET_PARAM_IF_SAME(param_name) \
892+
if (id == id_ ## param_name) { \
893+
ruby_whisper_params_set_ ## param_name(self, value); \
894+
continue; \
895+
}
896+
897+
static VALUE
898+
ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
899+
{
900+
901+
VALUE kw_hash;
902+
VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
903+
VALUE value;
904+
ruby_whisper_params *rwp;
905+
ID id;
906+
int i;
907+
908+
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
909+
if (NIL_P(kw_hash)) {
910+
return self;
911+
}
912+
913+
rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
914+
Data_Get_Struct(self, ruby_whisper_params, rwp);
915+
916+
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
917+
id = param_names[i];
918+
value = values[i];
919+
if (value == Qundef) {
920+
continue;
921+
}
922+
if (id == id_diarize) {
923+
rwp->diarize = value;
924+
continue;
925+
} else {
926+
SET_PARAM_IF_SAME(language)
927+
SET_PARAM_IF_SAME(translate)
928+
SET_PARAM_IF_SAME(no_context)
929+
SET_PARAM_IF_SAME(single_segment)
930+
SET_PARAM_IF_SAME(print_special)
931+
SET_PARAM_IF_SAME(print_progress)
932+
SET_PARAM_IF_SAME(print_realtime)
933+
SET_PARAM_IF_SAME(print_timestamps)
934+
SET_PARAM_IF_SAME(suppress_blank)
935+
SET_PARAM_IF_SAME(suppress_nst)
936+
SET_PARAM_IF_SAME(token_timestamps)
937+
SET_PARAM_IF_SAME(split_on_word)
938+
SET_PARAM_IF_SAME(initial_prompt)
939+
SET_PARAM_IF_SAME(offset)
940+
SET_PARAM_IF_SAME(duration)
941+
SET_PARAM_IF_SAME(max_text_tokens)
942+
SET_PARAM_IF_SAME(temperature)
943+
SET_PARAM_IF_SAME(max_initial_ts)
944+
SET_PARAM_IF_SAME(length_penalty)
945+
SET_PARAM_IF_SAME(temperature_inc)
946+
SET_PARAM_IF_SAME(entropy_thold)
947+
SET_PARAM_IF_SAME(logprob_thold)
948+
SET_PARAM_IF_SAME(no_speech_thold)
949+
SET_PARAM_IF_SAME(new_segment_callback)
950+
SET_PARAM_IF_SAME(new_segment_callback_user_data)
951+
SET_PARAM_IF_SAME(progress_callback)
952+
SET_PARAM_IF_SAME(progress_callback_user_data)
953+
SET_PARAM_IF_SAME(abort_callback)
954+
SET_PARAM_IF_SAME(abort_callback_user_data)
955+
}
956+
}
957+
958+
return self;
959+
}
960+
961+
#undef SET_PARAM_IF_SAME
962+
857963
/*
858964
* Hook called on new segment. Yields each Whisper::Segment.
859965
*
@@ -921,9 +1027,73 @@ ruby_whisper_params_abort_on(VALUE self)
9211027
void
9221028
init_ruby_whisper_params(VALUE *mWhisper)
9231029
{
1030+
id_language = rb_intern("language");
1031+
id_translate = rb_intern("translate");
1032+
id_no_context = rb_intern("no_context");
1033+
id_single_segment = rb_intern("single_segment");
1034+
id_print_special = rb_intern("print_special");
1035+
id_print_progress = rb_intern("print_progress");
1036+
id_print_realtime = rb_intern("print_realtime");
1037+
id_print_timestamps = rb_intern("print_timestamps");
1038+
id_suppress_blank = rb_intern("suppress_blank");
1039+
id_suppress_nst = rb_intern("suppress_nst");
1040+
id_token_timestamps = rb_intern("token_timestamps");
1041+
id_split_on_word = rb_intern("split_on_word");
1042+
id_initial_prompt = rb_intern("initial_prompt");
1043+
id_diarize = rb_intern("diarize");
1044+
id_offset = rb_intern("offset");
1045+
id_duration = rb_intern("duration");
1046+
id_max_text_tokens = rb_intern("max_text_tokens");
1047+
id_temperature = rb_intern("temperature");
1048+
id_max_initial_ts = rb_intern("max_initial_ts");
1049+
id_length_penalty = rb_intern("length_penalty");
1050+
id_temperature_inc = rb_intern("temperature_inc");
1051+
id_entropy_thold = rb_intern("entropy_thold");
1052+
id_logprob_thold = rb_intern("logprob_thold");
1053+
id_no_speech_thold = rb_intern("no_speech_thold");
1054+
id_new_segment_callback = rb_intern("new_segment_callback");
1055+
id_new_segment_callback_user_data = rb_intern("new_segment_callback_user_data");
1056+
id_progress_callback = rb_intern("progress_callback");
1057+
id_progress_callback_user_data = rb_intern("progress_callback_user_data");
1058+
id_abort_callback = rb_intern("abort_callback");
1059+
id_abort_callback_user_data = rb_intern("abort_callback_user_data");
1060+
1061+
param_names[0] = id_language;
1062+
param_names[1] = id_translate;
1063+
param_names[2] = id_no_context;
1064+
param_names[3] = id_single_segment;
1065+
param_names[4] = id_print_special;
1066+
param_names[5] = id_print_progress;
1067+
param_names[6] = id_print_realtime;
1068+
param_names[7] = id_print_timestamps;
1069+
param_names[8] = id_suppress_blank;
1070+
param_names[9] = id_suppress_nst;
1071+
param_names[10] = id_token_timestamps;
1072+
param_names[11] = id_split_on_word;
1073+
param_names[12] = id_initial_prompt;
1074+
param_names[13] = id_diarize;
1075+
param_names[14] = id_offset;
1076+
param_names[15] = id_duration;
1077+
param_names[16] = id_max_text_tokens;
1078+
param_names[17] = id_temperature;
1079+
param_names[18] = id_max_initial_ts;
1080+
param_names[19] = id_length_penalty;
1081+
param_names[20] = id_temperature_inc;
1082+
param_names[21] = id_entropy_thold;
1083+
param_names[22] = id_logprob_thold;
1084+
param_names[23] = id_no_speech_thold;
1085+
param_names[24] = id_new_segment_callback;
1086+
param_names[25] = id_new_segment_callback_user_data;
1087+
param_names[26] = id_progress_callback;
1088+
param_names[27] = id_progress_callback_user_data;
1089+
param_names[28] = id_abort_callback;
1090+
param_names[29] = id_abort_callback_user_data;
1091+
9241092
cParams = rb_define_class_under(*mWhisper, "Params", rb_cObject);
9251093

9261094
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
1095+
rb_define_method(cParams, "initialize", ruby_whisper_params_initialize, -1);
1096+
9271097
rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1);
9281098
rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0);
9291099
rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1);

0 commit comments

Comments
 (0)