Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil); err != nil {
if err := context.Process(samples, nil, nil, nil); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb, nil); err != nil {
if err := context.Process(data, nil, cb, nil); err != nil {
return err
}

Expand Down
35 changes: 19 additions & 16 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
Expand All @@ -177,30 +178,32 @@ func (context *context) Process(
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
}); err != nil {
return err
}

Expand Down
8 changes: 6 additions & 2 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function
type ProgressCallback func(int)

// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function.
//
// It takes no arguments and returns a bool. It is the first callback argument
// of Context.Process, which hands it on unchanged to the low-level
// Whisper_full / Whisper_full_parallel call; callers that do not need it
// pass nil.
//
// NOTE(review): presumably returning false stops processing and true lets it
// continue — confirm against the underlying whisper.cpp callback contract.
type EncoderBeginCallback func() bool

// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
Expand All @@ -31,7 +35,7 @@ type Model interface {
Languages() []string
}

// Context is the speach recognition context.
// Context is the speech recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
Expand All @@ -53,7 +57,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If a SegmentCallback is supplied, newly generated segments are
// passed to it during processing. Any callback argument may be nil.
Process([]float32, SegmentCallback, ProgressCallback) error
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error

// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
Expand Down
2 changes: 1 addition & 1 deletion bindings/go/pkg/whisper/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (model *model) NewContext() (Context, error) {
}

// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_BEAM_SEARCH)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
Expand Down
8 changes: 7 additions & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,13 @@ func Whisper_print_system_info() string {
// Return default parameters for a strategy
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
// Get default parameters
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
p := Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))

p.greedy.best_of = 5
p.thold_pt = 0
p.thold_ptsum = 0

return p
}

// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
Expand Down
Loading