Skip to content

Commit 97707fb

Browse files
committed
Adding in NewContextWithStrategy for go binding
Adding in NewContextWithStrategy, a function to allow you to specify the strategy when initializing the context. NewContext will use this "helper" function so that users don't need to think about what strategy they're specifying by default. I'm also doing it this way (instead of a function that sets the strategy) because the basis of what parameters we populate the context with is determined by the strategy we are using, so it needs to be done in the beginning of creating the context.
1 parent 996581c commit 97707fb

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

bindings/go/pkg/whisper/interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type Model interface {
2727

2828
// Return a new speech-to-text context.
2929
NewContext() (Context, error)
30+
NewContextWithStrategy(SamplingStrategy) (Context, error)
3031

3132
// Return true if the model is multilingual.
3233
IsMultilingual() bool

bindings/go/pkg/whisper/model.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ type model struct {
2020
// Make sure model adheres to the interface
2121
var _ Model = (*model)(nil)
2222

23+
type SamplingStrategy whisper.SamplingStrategy
24+
25+
const (
26+
SAMPLING_GREEDY SamplingStrategy = (SamplingStrategy)(whisper.SAMPLING_GREEDY)
27+
SAMPLING_BEAM_SEARCH SamplingStrategy = (SamplingStrategy)(whisper.SAMPLING_BEAM_SEARCH)
28+
)
29+
2330
///////////////////////////////////////////////////////////////////////////////
2431
// LIFECYCLE
2532

@@ -82,12 +89,17 @@ func (model *model) Languages() []string {
8289
}
8390

8491
func (model *model) NewContext() (Context, error) {
92+
// By default, specify the greedy strategy
93+
return model.NewContextWithStrategy(SAMPLING_GREEDY)
94+
}
95+
96+
func (model *model) NewContextWithStrategy(strategy SamplingStrategy) (Context, error) {
8597
if model.ctx == nil {
8698
return nil, ErrInternalAppError
8799
}
88100

89101
// Create new context
90-
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
102+
params := model.ctx.Whisper_full_default_params((whisper.SamplingStrategy)(strategy))
91103
params.SetTranslate(false)
92104
params.SetPrintSpecial(false)
93105
params.SetPrintProgress(false)

0 commit comments

Comments
 (0)