From 4c8651aa10d85d8525d2035ed6d439d84df06bfd Mon Sep 17 00:00:00 2001 From: Amanda Der Bedrosian Date: Thu, 27 Jun 2024 00:20:27 -0700 Subject: [PATCH 1/3] Implementing Encoder Begin Callback for golang binding Adding in EncoderBeginCallback to the Context's Process callback. This optional callback function returns false if computation should be aborted. --- bindings/go/README.md | 2 +- bindings/go/examples/go-whisper/process.go | 2 +- bindings/go/pkg/whisper/context.go | 35 ++++++++++++---------- bindings/go/pkg/whisper/interface.go | 8 +++-- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/bindings/go/README.md b/bindings/go/README.md index 1968cfd2470..245c10d1658 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -31,7 +31,7 @@ func main() { if err != nil { panic(err) } - if err := context.Process(samples, nil, nil); err != nil { + if err := context.Process(samples, nil, nil, nil); err != nil { return err } diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go index 71e52f01000..833947e843c 100644 --- a/bindings/go/examples/go-whisper/process.go +++ b/bindings/go/examples/go-whisper/process.go @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error { // Process the data fmt.Fprintf(flags.Output(), " ...processing %q\n", path) context.ResetTimings() - if err := context.Process(data, cb, nil); err != nil { + if err := context.Process(data, nil, cb, nil); err != nil { return err } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index ead92648f3e..87e27b8e763 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -163,6 +163,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f // Process new sample data and return any errors func (context *context) Process( data []float32, + callEncoderBegin EncoderBeginCallback, callNewSegment SegmentCallback, 
callProgress ProgressCallback, ) error { @@ -177,7 +178,20 @@ func (context *context) Process( // We don't do parallel processing at the moment processors := 0 if processors > 1 { - if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin, + func(new int) { + if callNewSegment != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + callNewSegment(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err + } + } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin, + func(new int) { if callNewSegment != nil { num_segments := context.model.ctx.Whisper_full_n_segments() s0 := num_segments - new @@ -185,22 +199,11 @@ func (context *context) Process( callNewSegment(toSegment(context.model.ctx, i)) } } - }); err != nil { - return err - } - } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { - if callNewSegment != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() - s0 := num_segments - new - for i := s0; i < num_segments; i++ { - callNewSegment(toSegment(context.model.ctx, i)) + }, func(progress int) { + if callProgress != nil { + callProgress(progress) } - } - }, func(progress int) { - if callProgress != nil { - callProgress(progress) - } - }); err != nil { + }); err != nil { return err } diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index b430e7ce853..5c7554c7bdb 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -16,6 +16,10 @@ type SegmentCallback func(Segment) // processing. It is called during the Process function type ProgressCallback func(int) +// EncoderBeginCallback is the callback function for checking if we want to +// continue processing. 
It is called during the Process function +type EncoderBeginCallback func() bool + // Model is the interface to a whisper model. Create a new model with the // function whisper.New(string) type Model interface { @@ -31,7 +35,7 @@ type Model interface { Languages() []string } -// Context is the speach recognition context. +// Context is the speech recognition context. type Context interface { SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. SetTranslate(bool) // Set translate flag @@ -53,7 +57,7 @@ type Context interface { // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the // callback function during processing. - Process([]float32, SegmentCallback, ProgressCallback) error + Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error // After process is called, return segments until the end of the stream // is reached, when io.EOF is returned. From cdce0ff1573d5f24f11281f7788f7de78db5561d Mon Sep 17 00:00:00 2001 From: Amanda Der Bedrosian Date: Fri, 20 Sep 2024 19:29:12 -0700 Subject: [PATCH 2/3] Updating default parameters to match C++ example binary Updating the default parameters to match the C++ binary's parameters. 
This includes: - updating the strategy to SAMPLING_BEAM_SEARCH (previously SAMPLING_GREEDY) - greedy.best_of = 5 (previously -1) - thold_pt = 0 (previously .01) - thold_ptsum = 0 (previously .01) --- bindings/go/pkg/whisper/model.go | 2 +- bindings/go/whisper.go | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a150223c7..f3631c0e990 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -87,7 +87,7 @@ func (model *model) NewContext() (Context, error) { } // Create new context - params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY) + params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_BEAM_SEARCH) params.SetTranslate(false) params.SetPrintSpecial(false) params.SetPrintProgress(false) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 87da83f0f10..c07085f75f8 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -296,7 +296,13 @@ func Whisper_print_system_info() string { // Return default parameters for a strategy func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params { // Get default parameters - return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy))) + p := Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy))) + + p.greedy.best_of = 5 + p.thold_pt = 0 + p.thold_ptsum = 0 + + return p } // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text From 8d60666cbc5bad8670116542fb6302b3145af193 Mon Sep 17 00:00:00 2001 From: Amanda Der Bedrosian Date: Thu, 26 Sep 2024 11:20:12 -0700 Subject: [PATCH 3/3] Adding in a way to retrieve the detected language for golang bindings Adding in a function, GetDetectedLanguage, which will retrieve the detected language, if available.
--- bindings/go/pkg/whisper/context.go | 4 ++++ bindings/go/pkg/whisper/interface.go | 10 ++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 87e27b8e763..d1b6d120066 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -289,6 +289,10 @@ func (context *context) IsLANG(t Token, lang string) bool { } } +func (context *context) GetDetectedLanguage() string { + return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id()) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 5c7554c7bdb..657f1c1a08a 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -37,10 +37,12 @@ type Model interface { // Context is the speech recognition context. type Context interface { - SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. - SetTranslate(bool) // Set translate flag - IsMultilingual() bool // Return true if the model is multilingual. - Language() string // Get language + SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. + SetTranslate(bool) // Set translate flag + IsMultilingual() bool // Return true if the model is multilingual. + Language() string // Get language + GetDetectedLanguage() string // Get auto detected language + SetOffset(time.Duration) // Set offset SetDuration(time.Duration) // Set duration