Skip to content

Commit 76002a5

Browse files
committed
Add progress and abort callback features
1 parent 820c721 commit 76002a5

File tree

2 files changed

+163
-3
lines changed

2 files changed

+163
-3
lines changed

bindings/ruby/ext/ruby_whisper.cpp

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,16 @@ void rb_whisper_free(ruby_whisper *rw) {
107107
free(rw);
108108
}
109109

110+
// GC mark helper for one callback container: marks the user data, the
// callback, and the callbacks array held by rwc.
// NOTE(review): "callbcack" looks like a typo for "callback"; renaming it
// requires updating every call site in the same change.
void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) {
111+
rb_gc_mark(rwc->user_data);
112+
rb_gc_mark(rwc->callback);
113+
rb_gc_mark(rwc->callbacks);
114+
}
115+
110116
// GC mark function for ruby_whisper_params (passed to Data_Wrap_Struct when
// the params object is allocated). Marks the Ruby values held by the
// new-segment, progress, and abort callback containers.
void rb_whisper_params_mark(ruby_whisper_params *rwp) {
111-
rb_gc_mark(rwp->new_segment_callback_container->user_data);
112-
rb_gc_mark(rwp->new_segment_callback_container->callback);
113-
rb_gc_mark(rwp->new_segment_callback_container->callbacks);
117+
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
118+
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
119+
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
114120
}
115121

116122
void rb_whisper_params_free(ruby_whisper_params *rwp) {
@@ -141,6 +147,8 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
141147
rwp = ALLOC(ruby_whisper_params);
142148
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
143149
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
150+
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
151+
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
144152
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
145153
}
146154

@@ -316,6 +324,54 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
316324
rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container;
317325
}
318326

327+
if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) {
328+
rwp->params.progress_callback = [](struct whisper_context *ctx, struct whisper_state * /*state*/, int progress_cur, void *user_data) {
329+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
330+
const VALUE progress = INT2NUM(progress_cur);
331+
// The state argument is not passed to the Ruby callback yet (nil is passed
332+
// instead) because doing so requires resolving GC-related problems first.
333+
if (!NIL_P(container->callback)) {
334+
rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data);
335+
}
336+
const long callbacks_len = RARRAY_LEN(container->callbacks);
337+
if (0 == callbacks_len) {
338+
return;
339+
}
340+
for (int j = 0; j < callbacks_len; j++) {
341+
VALUE cb = rb_ary_entry(container->callbacks, j);
342+
rb_funcall(cb, id_call, 1, progress);
343+
}
344+
};
345+
rwp->progress_callback_container->context = &self;
346+
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
347+
}
348+
349+
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
350+
rwp->params.abort_callback = [](void * user_data) {
351+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
352+
if (!NIL_P(container->callback)) {
353+
VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data);
354+
if (!NIL_P(result) && Qfalse != result) {
355+
return true;
356+
}
357+
}
358+
const long callbacks_len = RARRAY_LEN(container->callbacks);
359+
if (0 == callbacks_len) {
360+
return false;
361+
}
362+
for (int j = 0; j < callbacks_len; j++) {
363+
VALUE cb = rb_ary_entry(container->callbacks, j);
364+
VALUE result = rb_funcall(cb, id_call, 1, container->user_data);
365+
if (!NIL_P(result) && Qfalse != result) {
366+
return true;
367+
}
368+
}
369+
return false;
370+
};
371+
rwp->abort_callback_container->context = &self;
372+
rwp->params.abort_callback_user_data = rwp->abort_callback_container;
373+
}
374+
319375
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
320376
fprintf(stderr, "failed to process audio\n");
321377
return self;
@@ -895,6 +951,62 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self,
895951
rwp->new_segment_callback_container->user_data = value;
896952
return value;
897953
}
954+
/*
955+
* Sets progress callback, called on each progress update.
956+
*
957+
* params.progress_callback = ->(context, _, progress, user_data) {
958+
* # ...
959+
* }
960+
*
961+
* call-seq:
962+
* progress_callback = callback -> callback
963+
*/
964+
static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) {
965+
ruby_whisper_params *rwp;
966+
Data_Get_Struct(self, ruby_whisper_params, rwp);
967+
// Only stored here; transcribe installs the native progress callback when
// this value is non-nil (or when on_progress blocks were registered).
rwp->progress_callback_container->callback = value;
968+
// Return the assigned value, as Ruby setters conventionally do.
return value;
969+
}
970+
/*
971+
* Sets user data passed to the last argument of progress callback.
972+
*
973+
* call-seq:
974+
* progress_callback_user_data = user_data -> user_data
975+
*/
976+
static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) {
977+
ruby_whisper_params *rwp;
978+
Data_Get_Struct(self, ruby_whisper_params, rwp);
979+
// Forwarded as the fourth argument when the progress callback is invoked;
// blocks registered through on_progress do not receive it.
rwp->progress_callback_container->user_data = value;
980+
return value;
981+
}
982+
/*
983+
* Sets abort callback, called to check if the process should be aborted.
984+
*
985+
* params.abort_callback = ->(user_data) {
986+
* # ...
987+
* }
988+
*
989+
* call-seq:
990+
* abort_callback = callback -> callback
991+
*/
992+
static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) {
993+
ruby_whisper_params *rwp;
994+
Data_Get_Struct(self, ruby_whisper_params, rwp);
995+
// Only stored here; transcribe installs the native abort callback when this
// value is non-nil (or when abort_on blocks were registered). A result other
// than nil/false from the callback requests an abort.
rwp->abort_callback_container->callback = value;
996+
return value;
997+
}
998+
/*
999+
* Sets user data passed to the last argument of abort callback.
1000+
*
1001+
* call-seq:
1002+
* abort_callback_user_data = user_data -> user_data
1003+
*/
1004+
static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) {
1005+
ruby_whisper_params *rwp;
1006+
Data_Get_Struct(self, ruby_whisper_params, rwp);
1007+
// Passed as the sole argument to the abort callback and to every block
// registered through abort_on.
rwp->abort_callback_container->user_data = value;
1008+
return value;
1009+
}
8981010

8991011
// High level API
9001012

@@ -977,6 +1089,46 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
9771089
return Qnil;
9781090
}
9791091

1092+
/*
1093+
* Hook called on progress update. Yields each progress Integer between 0 and 100.
1094+
*
1095+
* params.on_progress do |progress|
1096+
* # ...
1097+
* end
1098+
*
1099+
* call-seq:
1100+
* on_progress {|progress| ... }
1101+
*/
1102+
static VALUE ruby_whisper_params_on_progress(VALUE self) {
1103+
ruby_whisper_params *rws;
1104+
Data_Get_Struct(self, ruby_whisper_params, rws);
1105+
// Capture the caller's block as a Proc so it can be kept and invoked later.
const VALUE blk = rb_block_proc();
1106+
// Appended, not replaced: every registered block is called, in order, with
// the progress value as its only argument.
rb_ary_push(rws->progress_callback_container->callbacks, blk);
1107+
return Qnil;
1108+
}
1109+
1110+
/*
1111+
* Calls the block to determine whether to abort. Return a truthy value (e.g. +true+) from the block to abort.
1112+
*
1113+
* params.abort_on do
1114+
* if some_condition
1115+
* true # abort
1116+
* else
1117+
* false # continue
1118+
* end
1119+
* end
1120+
*
1121+
* call-seq:
1122+
* abort_on { ... }
1123+
*/
1124+
static VALUE ruby_whisper_params_abort_on(VALUE self) {
1125+
ruby_whisper_params *rws;
1126+
Data_Get_Struct(self, ruby_whisper_params, rws);
1127+
// Capture the caller's block as a Proc so it can be kept and invoked later.
const VALUE blk = rb_block_proc();
1128+
// Appended, not replaced: blocks are consulted in order with the abort
// user data, and the first result other than nil/false requests an abort.
rb_ary_push(rws->abort_callback_container->callbacks, blk);
1129+
return Qnil;
1130+
}
1131+
9801132
/*
9811133
* Start time in milliseconds.
9821134
*
@@ -1115,13 +1267,19 @@ void Init_whisper() {
11151267

11161268
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
11171269
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
1270+
rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1);
1271+
rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1);
1272+
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
1273+
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
11181274

11191275
// High level
11201276
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
11211277

11221278
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
11231279
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
11241280
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
1281+
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1282+
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
11251283
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
11261284
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
11271285
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);

bindings/ruby/ext/ruby_whisper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ typedef struct {
1818
struct whisper_full_params params;
1919
bool diarize;
2020
ruby_whisper_callback_container *new_segment_callback_container;
21+
ruby_whisper_callback_container *progress_callback_container;
22+
ruby_whisper_callback_container *abort_callback_container;
2123
} ruby_whisper_params;
2224

2325
#endif

0 commit comments

Comments
 (0)