Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil); err != nil {
if err := context.Process(samples, nil, nil, nil); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb, nil); err != nil {
if err := context.Process(data, nil, cb, nil); err != nil {
return err
}

Expand Down
35 changes: 19 additions & 16 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
Expand All @@ -203,30 +204,32 @@ func (context *context) Process(
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
}); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/go/pkg/whisper/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
context, err := model.NewContext()
assert.NoError(err)

err = context.Process(data, nil, nil)
err = context.Process(data, nil, nil, nil)
assert.NoError(err)
}
8 changes: 6 additions & 2 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function
type ProgressCallback func(int)

// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool

// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
Expand All @@ -31,7 +35,7 @@ type Model interface {
Languages() []string
}

// Context is the speach recognition context.
// Context is the speech recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
Expand All @@ -58,7 +62,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback, ProgressCallback) error
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error

// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
Expand Down
Loading