@@ -107,10 +107,16 @@ void rb_whisper_free(ruby_whisper *rw) {
107107 free (rw);
108108}
109109
110+ void rb_whisper_callbcack_container_mark (ruby_whisper_callback_container *rwc) {
111+ rb_gc_mark (rwc->user_data );
112+ rb_gc_mark (rwc->callback );
113+ rb_gc_mark (rwc->callbacks );
114+ }
115+
110116void rb_whisper_params_mark (ruby_whisper_params *rwp) {
111- rb_gc_mark (rwp->new_segment_callback_container -> user_data );
112- rb_gc_mark (rwp->new_segment_callback_container -> callback );
113- rb_gc_mark (rwp->new_segment_callback_container -> callbacks );
117+ rb_whisper_callbcack_container_mark (rwp->new_segment_callback_container );
118+ rb_whisper_callbcack_container_mark (rwp->progress_callback_container );
119+ rb_whisper_callbcack_container_mark (rwp->abort_callback_container );
114120}
115121
116122void rb_whisper_params_free (ruby_whisper_params *rwp) {
@@ -141,6 +147,8 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
141147 rwp = ALLOC (ruby_whisper_params);
142148 rwp->params = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
143149 rwp->new_segment_callback_container = rb_whisper_callback_container_allocate ();
150+ rwp->progress_callback_container = rb_whisper_callback_container_allocate ();
151+ rwp->abort_callback_container = rb_whisper_callback_container_allocate ();
144152 return Data_Wrap_Struct (klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
145153}
146154
@@ -316,6 +324,54 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
316324 rwp->params .new_segment_callback_user_data = rwp->new_segment_callback_container ;
317325 }
318326
327+ if (!NIL_P (rwp->progress_callback_container ->callback ) || 0 != RARRAY_LEN (rwp->progress_callback_container ->callbacks )) {
328+ rwp->params .progress_callback = [](struct whisper_context *ctx, struct whisper_state * /* state*/ , int progress_cur, void *user_data) {
329+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
330+ const VALUE progress = INT2NUM (progress_cur);
331+ // Currently, doesn't support state because
332+ // those require to resolve GC-related problems.
333+ if (!NIL_P (container->callback )) {
334+ rb_funcall (container->callback , id_call, 4 , *container->context , Qnil, progress, container->user_data );
335+ }
336+ const long callbacks_len = RARRAY_LEN (container->callbacks );
337+ if (0 == callbacks_len) {
338+ return ;
339+ }
340+ for (int j = 0 ; j < callbacks_len; j++) {
341+ VALUE cb = rb_ary_entry (container->callbacks , j);
342+ rb_funcall (cb, id_call, 1 , progress);
343+ }
344+ };
345+ rwp->progress_callback_container ->context = &self;
346+ rwp->params .progress_callback_user_data = rwp->progress_callback_container ;
347+ }
348+
349+ if (!NIL_P (rwp->abort_callback_container ->callback ) || 0 != RARRAY_LEN (rwp->abort_callback_container ->callbacks )) {
350+ rwp->params .abort_callback = [](void * user_data) {
351+ const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
352+ if (!NIL_P (container->callback )) {
353+ VALUE result = rb_funcall (container->callback , id_call, 1 , container->user_data );
354+ if (!NIL_P (result) && Qfalse != result) {
355+ return true ;
356+ }
357+ }
358+ const long callbacks_len = RARRAY_LEN (container->callbacks );
359+ if (0 == callbacks_len) {
360+ return false ;
361+ }
362+ for (int j = 0 ; j < callbacks_len; j++) {
363+ VALUE cb = rb_ary_entry (container->callbacks , j);
364+ VALUE result = rb_funcall (cb, id_call, 1 , container->user_data );
365+ if (!NIL_P (result) && Qfalse != result) {
366+ return true ;
367+ }
368+ }
369+ return false ;
370+ };
371+ rwp->abort_callback_container ->context = &self;
372+ rwp->params .abort_callback_user_data = rwp->abort_callback_container ;
373+ }
374+
319375 if (whisper_full_parallel (rw->context , rwp->params , pcmf32.data (), pcmf32.size (), 1 ) != 0 ) {
320376 fprintf (stderr, " failed to process audio\n " );
321377 return self;
@@ -895,6 +951,62 @@ static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self,
895951 rwp->new_segment_callback_container ->user_data = value;
896952 return value;
897953}
954+ /*
955+ * Sets progress callback, called on each progress update.
956+ *
957+ * params.new_segment_callback = ->(context, _, n_new, user_data) {
958+ * # ...
959+ * }
960+ *
961+ * call-seq:
962+ * progress_callback = callback -> callback
963+ */
964+ static VALUE ruby_whisper_params_set_progress_callback (VALUE self, VALUE value) {
965+ ruby_whisper_params *rwp;
966+ Data_Get_Struct (self, ruby_whisper_params, rwp);
967+ rwp->progress_callback_container ->callback = value;
968+ return value;
969+ }
970+ /*
971+ * Sets user data passed to the last argument of progress callback.
972+ *
973+ * call-seq:
974+ * progress_callback_user_data = user_data -> use_data
975+ */
976+ static VALUE ruby_whisper_params_set_progress_callback_user_data (VALUE self, VALUE value) {
977+ ruby_whisper_params *rwp;
978+ Data_Get_Struct (self, ruby_whisper_params, rwp);
979+ rwp->progress_callback_container ->user_data = value;
980+ return value;
981+ }
982+ /*
983+ * Sets abort callback, called to check if the process should be aborted.
984+ *
985+ * params.abort_callback = ->(user_data) {
986+ * # ...
987+ * }
988+ *
989+ * call-seq:
990+ * abort_callback = callback -> callback
991+ */
992+ static VALUE ruby_whisper_params_set_abort_callback (VALUE self, VALUE value) {
993+ ruby_whisper_params *rwp;
994+ Data_Get_Struct (self, ruby_whisper_params, rwp);
995+ rwp->abort_callback_container ->callback = value;
996+ return value;
997+ }
998+ /*
999+ * Sets user data passed to the last argument of abort callback.
1000+ *
1001+ * call-seq:
1002+ * abort_callback_user_data = user_data -> use_data
1003+ */
1004+ static VALUE ruby_whisper_params_set_abort_callback_user_data (VALUE self, VALUE value) {
1005+ ruby_whisper_params *rwp;
1006+ Data_Get_Struct (self, ruby_whisper_params, rwp);
1007+ rwp->abort_callback_container ->user_data = value;
1008+ return value;
1009+ }
8981010
8991011// High level API
9001012
@@ -977,6 +1089,46 @@ static VALUE ruby_whisper_params_on_new_segment(VALUE self) {
9771089 return Qnil;
9781090}
9791091
1092+ /*
1093+ * Hook called on progress update. Yields each progress Integer between 0 and 100.
1094+ *
1095+ * whisper.on_progress do |progress|
1096+ * # ...
1097+ * end
1098+ *
1099+ * call-seq:
1100+ * on_progress {|progress| ... }
1101+ */
1102+ static VALUE ruby_whisper_params_on_progress (VALUE self) {
1103+ ruby_whisper_params *rws;
1104+ Data_Get_Struct (self, ruby_whisper_params, rws);
1105+ const VALUE blk = rb_block_proc ();
1106+ rb_ary_push (rws->progress_callback_container ->callbacks , blk);
1107+ return Qnil;
1108+ }
1109+
1110+ /*
1111+ * Call block to determine whether abort or not. Return +true+ when you want to abort.
1112+ *
1113+ * params.abort_on do
1114+ * if some_condition
1115+ * true # abort
1116+ * else
1117+ * false # continue
1118+ * end
1119+ * end
1120+ *
1121+ * call-seq:
1122+ * abort_on { ... }
1123+ */
1124+ static VALUE ruby_whisper_params_abort_on (VALUE self) {
1125+ ruby_whisper_params *rws;
1126+ Data_Get_Struct (self, ruby_whisper_params, rws);
1127+ const VALUE blk = rb_block_proc ();
1128+ rb_ary_push (rws->abort_callback_container ->callbacks , blk);
1129+ return Qnil;
1130+ }
1131+
9801132/*
9811133 * Start time in milliseconds.
9821134 *
@@ -1115,13 +1267,19 @@ void Init_whisper() {
11151267
11161268 rb_define_method (cParams, " new_segment_callback=" , ruby_whisper_params_set_new_segment_callback, 1 );
11171269 rb_define_method (cParams, " new_segment_callback_user_data=" , ruby_whisper_params_set_new_segment_callback_user_data, 1 );
1270+ rb_define_method (cParams, " progress_callback=" , ruby_whisper_params_set_progress_callback, 1 );
1271+ rb_define_method (cParams, " progress_callback_user_data=" , ruby_whisper_params_set_progress_callback_user_data, 1 );
1272+ rb_define_method (cParams, " abort_callback=" , ruby_whisper_params_set_abort_callback, 1 );
1273+ rb_define_method (cParams, " abort_callback_user_data=" , ruby_whisper_params_set_abort_callback_user_data, 1 );
11181274
11191275 // High leve
11201276 cSegment = rb_define_class_under (mWhisper , " Segment" , rb_cObject);
11211277
11221278 rb_define_alloc_func (cSegment, ruby_whisper_segment_allocate);
11231279 rb_define_method (cContext, " each_segment" , ruby_whisper_each_segment, 0 );
11241280 rb_define_method (cParams, " on_new_segment" , ruby_whisper_params_on_new_segment, 0 );
1281+ rb_define_method (cParams, " on_progress" , ruby_whisper_params_on_progress, 0 );
1282+ rb_define_method (cParams, " abort_on" , ruby_whisper_params_abort_on, 0 );
11251283 rb_define_method (cSegment, " start_time" , ruby_whisper_segment_get_start_time, 0 );
11261284 rb_define_method (cSegment, " end_time" , ruby_whisper_segment_get_end_time, 0 );
11271285 rb_define_method (cSegment, " speaker_next_turn?" , ruby_whisper_segment_get_speaker_turn_next, 0 );
0 commit comments