diff --git a/openfeature/multi/README.md b/openfeature/multi/README.md index 09deba12..75e3359f 100644 --- a/openfeature/multi/README.md +++ b/openfeature/multi/README.md @@ -32,12 +32,13 @@ import ( "github.com/open-feature/go-sdk/openfeature/memprovider" ) -providers := make(multi.ProviderMap) -providers["providerA"] = memprovider.NewInMemoryProvider(map[string]memprovider.InMemoryFlag{}) -providers["providerB"] = myCustomProvider -mprovider, err := multi.NewProvider(providers, multi.StrategyFirstMatch) +mprovider, err := multi.NewProvider( + multi.StrategyFirstMatch, + multi.WithProvider("providerA", memprovider.NewInMemoryProvider(/*...*/)), + multi.WithProvider("providerB", myCustomProvider), +) if err != nil { - return err + return err } openfeature.SetNamedProviderAndWait("multiprovider", mprovider) @@ -101,7 +102,7 @@ type StrategyConstructor func(providers []*NamedProvider) StrategyFn[FlagTypes] Build your strategy to wrap around the slice of providers ```go -option := multi.WithCustomStrategy(func(providers []*NamedProvider) StrategyFn[FlagTypes] { +option := multi.WithCustomStrategy(func(providers []NamedProvider) StrategyFn[FlagTypes] { return func[T FlagTypes](ctx context.Context, flag string, defaultValue T, flatCtx openfeature.FlattenedContext) openfeature.GenericResolutionDetail[T] { // implementation // ... @@ -140,12 +141,12 @@ essentially a factory that allows the `StrategyFn` to wrap around a slice of `Na Allows for setting global hooks for the multi-provider. These are `openfeature.Hook` implementations that affect **all** internal `FeatureProvider` instances. -### `WithProviderHooks` +### `WithProvider` -Allows for setting `openfeature.Hook` implementations on a specific named `FeatureProvider` within the multi-provider. -This should only be used when hooks need to be attached to a `FeatureProvider` instance that does not implement that functionality. -Using a provider name that is not known will cause an error to be returned during the creation time. This option can be -used multiple times using unique provider names. +Allows for registering a specific `FeatureProvider` instance under a unique provider name. Optional `openfeature.Hook` +implementations may also be provided, which will execute only for this specific provider. This option can be used multiple +times with unique provider names to register multiple providers. +The order in which `WithProvider` options are provided determines the order in which the providers are registered and evaluated. ## `StrategyComparision` specific options diff --git a/openfeature/multi/comparison_strategy.go b/openfeature/multi/comparison_strategy.go index 92d7068d..bd0fe15c 100644 --- a/openfeature/multi/comparison_strategy.go +++ b/openfeature/multi/comparison_strategy.go @@ -24,7 +24,7 @@ type Comparator func(values []any) bool // can be passed as long as ObjectEvaluation is never called with objects that are not comparable. The custom [Comparator] // will only be used for [of.FeatureProvider.ObjectEvaluation] if set. If [of.FeatureProvider.ObjectEvaluation] is // called without setting a [Comparator], and the returned object(s) are not comparable, then an error will occur. -func newComparisonStrategy(providers []*NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[FlagTypes] { +func newComparisonStrategy(providers []NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[FlagTypes] { return evaluateComparison[FlagTypes](providers, fallbackProvider, comparator) } @@ -81,7 +81,7 @@ func comparisonResolutionError(metadata of.FlagMetadata) of.ResolutionError { return of.NewGeneralResolutionError("comparison failure") } -func evaluateComparison[T FlagTypes](providers []*NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[T] { +func evaluateComparison[T FlagTypes](providers []NamedProvider, fallbackProvider of.FeatureProvider, comparator Comparator) StrategyFn[T] { return func(ctx context.Context, flag string, defaultValue T, evalCtx of.FlattenedContext) of.GenericResolutionDetail[T] { if comparator == nil { comparator = defaultComparator @@ -103,7 +103,7 @@ func evaluateComparison[T FlagTypes](providers []*NamedProvider, fallbackProvide // Short circuit if there's only one provider as no comparison nor workers are needed if len(providers) == 1 { result := Evaluate(ctx, providers[0], flag, defaultValue, evalCtx) - metadata := setFlagMetadata(StrategyComparison, providers[0].Name, make(of.FlagMetadata)) + metadata := setFlagMetadata(StrategyComparison, providers[0].Name(), make(of.FlagMetadata)) metadata[MetadataFallbackUsed] = false result.FlagMetadata = mergeFlagMeta(result.FlagMetadata, metadata) return result @@ -124,13 +124,13 @@ func evaluateComparison[T FlagTypes](providers []*NamedProvider, fallbackProvide notFound := result.ResolutionDetail().ErrorCode == of.FlagNotFoundCode if !notFound && result.Error() != nil { return &ProviderError{ - ProviderName: closedProvider.Name, - Err: result.Error(), + ProviderName: closedProvider.Name(), + err: result.Error(), } } if !notFound { resultChan <- &namedResult{ - name: closedProvider.Name, + name: closedProvider.Name(), res: &result, } } else { @@ -225,7 +225,7 @@ func evaluateComparison[T FlagTypes](providers []*NamedProvider, fallbackProvide if fallbackProvider != nil { fallbackResult := Evaluate( ctx, - &NamedProvider{Name: "fallback", FeatureProvider: fallbackProvider}, + &namedProvider{name: "fallback", FeatureProvider: fallbackProvider}, flag, defaultValue, evalCtx, diff --git a/openfeature/multi/comparison_strategy_test.go b/openfeature/multi/comparison_strategy_test.go index f6a76c8a..a11b7091 100644 --- a/openfeature/multi/comparison_strategy_test.go +++ b/openfeature/multi/comparison_strategy_test.go @@ -104,9 +104,9 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { fallback := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider", FeatureProvider: provider, }, }, fallback, nil) @@ -128,13 +128,13 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -159,17 +159,17 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider3 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, - { - Name: "test-provider3", + &namedProvider{ + name: "test-provider3", FeatureProvider: provider3, }, }, fallback, nil) @@ -193,17 +193,17 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider3 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, - { - Name: "test-provider3", + &namedProvider{ + name: "test-provider3", FeatureProvider: provider3, }, }, fallback, nil) @@ -231,21 +231,21 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider4 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider4, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, - { - Name: "test-provider3", + &namedProvider{ + name: "test-provider3", FeatureProvider: provider3, }, - { - Name: "test-provider4", + &namedProvider{ + name: "test-provider4", FeatureProvider: provider4, }, }, fallback, nil) @@ -271,17 +271,17 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider3 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider3, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, - { - Name: "test-provider3", + &namedProvider{ + name: "test-provider3", FeatureProvider: provider3, }, }, fallback, nil) @@ -305,13 +305,13 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, defaultVal, true, TestErrorNotFound, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -340,21 +340,21 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider4 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider4, defaultVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, - { - Name: "test-provider3", + &namedProvider{ + name: "test-provider3", FeatureProvider: provider3, }, - { - Name: "test-provider4", + &namedProvider{ + name: "test-provider4", FeatureProvider: provider4, }, }, fallback, nil) @@ -378,13 +378,13 @@ func Test_ComparisonStrategy_Evaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, defaultVal, true, TestErrorError, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -497,13 +497,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, testCase.successValue, true, TestErrorNone, true) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -528,13 +528,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, tc.defaultValue, true, TestErrorNone, true) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -558,13 +558,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, successVal, true, TestErrorNone, true) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -589,13 +589,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, defaultVal, true, TestErrorNone, true) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) @@ -619,13 +619,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, successVal, true, TestErrorNone, true) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, func(val []any) bool { @@ -651,13 +651,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, successVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, func(val []any) bool { @@ -686,13 +686,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, defaultVal, true, TestErrorNone, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, func(val []any) bool { @@ -718,13 +718,13 @@ func Test_ComparisonStrategy_ObjectEvaluation(t *testing.T) { configureComparisonProvider(provider1, successVal, true, TestErrorNone, false) provider2 := of.NewMockFeatureProvider(ctrl) configureComparisonProvider(provider2, successVal, true, TestErrorError, false) - strategy := newComparisonStrategy([]*NamedProvider{ - { - Name: "test-provider1", + strategy := newComparisonStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider1", FeatureProvider: provider1, }, - { - Name: "test-provider2", + &namedProvider{ + name: "test-provider2", FeatureProvider: provider2, }, }, fallback, nil) diff --git a/openfeature/multi/errors.go b/openfeature/multi/errors.go index b3150c86..7fdd325a 100644 --- a/openfeature/multi/errors.go +++ b/openfeature/multi/errors.go @@ -3,13 +3,14 @@ package multi import ( "errors" "fmt" + "sync" ) type ( // ProviderError is an error wrapper that includes the provider name. ProviderError struct { - // Err is the original error that was returned from a provider - Err error + // err is the original error that was returned from a provider + err error // ProviderName is the name of the provider that returned the included error ProviderName string } @@ -24,8 +25,17 @@ var ( _ error = (AggregateError)(nil) ) +// Error implements the error interface for ProviderError. func (e *ProviderError) Error() string { - return fmt.Sprintf("Provider %s: %s", e.ProviderName, e.Err.Error()) + if e.err == nil { + return fmt.Sprintf("Provider %s: ", e.ProviderName) + } + return fmt.Sprintf("Provider %s: %s", e.ProviderName, e.err.Error()) +} + +// Unwrap allows access to the original error, if any. +func (e *ProviderError) Unwrap() error { + return e.err } // NewAggregateError creates a new AggregateError from a slice of [ProviderError] instances @@ -44,3 +54,30 @@ func (ae AggregateError) Error() string { } return errors.Join(errs...).Error() } + +// multiErrGroup collects all errors from concurrent goroutines. +type multiErrGroup struct { + wg sync.WaitGroup + mu sync.Mutex + errors []error +} + +// Go starts a function in a goroutine. +func (g *multiErrGroup) Go(fn func() error) { + g.wg.Add(1) + go func() { + defer g.wg.Done() + if err := fn(); err != nil { + g.mu.Lock() + g.errors = append(g.errors, err) + g.mu.Unlock() + } + }() +} + +// Wait waits for all goroutines to complete. +// Returns a combined error or nil if none. +func (g *multiErrGroup) Wait() error { + g.wg.Wait() + return errors.Join(g.errors...) +} diff --git a/openfeature/multi/errors_test.go b/openfeature/multi/errors_test.go index 2bc75933..3e56a25c 100644 --- a/openfeature/multi/errors_test.go +++ b/openfeature/multi/errors_test.go @@ -1,12 +1,36 @@ package multi import ( + "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func Test_ProviderError_Error(t *testing.T) { + t.Run("with nil err", func(t *testing.T) { + err := &ProviderError{ + ProviderName: "TestError", + } + assert.EqualError(t, err, "Provider TestError: ") + assert.Equal(t, "TestError", err.ProviderName) + }) + + t.Run("with custom error", func(t *testing.T) { + originalErr := errors.New("custom error message") + err := &ProviderError{ + ProviderName: "TestError", + err: originalErr, + } + assert.EqualError(t, err, "Provider TestError: custom error message") + assert.Equal(t, "TestError", err.ProviderName) + assert.ErrorIs(t, err, originalErr) + }) +} + func Test_AggregateError_Error(t *testing.T) { t.Run("empty error", func(t *testing.T) { err := NewAggregateError([]ProviderError{}) @@ -16,7 +40,7 @@ func Test_AggregateError_Error(t *testing.T) { t.Run("single error", func(t *testing.T) { err := NewAggregateError([]ProviderError{ { - Err: fmt.Errorf("test error"), + err: fmt.Errorf("test error"), ProviderName: "test-provider", }, }) @@ -27,11 +51,11 @@ func Test_AggregateError_Error(t *testing.T) { t.Run("multiple errors", func(t *testing.T) { err := NewAggregateError([]ProviderError{ { - Err: fmt.Errorf("test error"), + err: fmt.Errorf("test error"), ProviderName: "test-provider1", }, { - Err: fmt.Errorf("test error"), + err: fmt.Errorf("test error"), ProviderName: "test-provider2", }, }) @@ -39,3 +63,90 @@ func Test_AggregateError_Error(t *testing.T) { assert.Equal(t, "Provider test-provider1: test error\nProvider test-provider2: test error", err.Error()) }) } + +func Test_multiErrGroup(t *testing.T) { + tests := []struct { + name string + setup func(*multiErrGroup) + wantErrs []string + }{ + { + name: "no errors - all goroutines succeed", + setup: func(meg *multiErrGroup) { + for range 3 { + meg.Go(func() error { + time.Sleep(10 * time.Millisecond) + return nil + }) + } + }, + wantErrs: nil, + }, + { + name: "single error among successful goroutines", + setup: func(meg *multiErrGroup) { + meg.Go(func() error { + return errors.New("error 0") + }) + meg.Go(func() error { + return nil + }) + }, + wantErrs: []string{"error 0"}, + }, + { + name: "multiple errors collected", + setup: func(meg *multiErrGroup) { + meg.Go(func() error { + return errors.New("error 1") + }) + meg.Go(func() error { + return nil + }) + meg.Go(func() error { + return errors.New("error 2") + }) + meg.Go(func() error { + return errors.New("error 3") + }) + }, + wantErrs: []string{"error 1", "error 2", "error 3"}, + }, + { + name: "empty group returns no error", + setup: func(meg *multiErrGroup) {}, + wantErrs: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var meg multiErrGroup + tt.setup(&meg) + + err := meg.Wait() + + if tt.wantErrs != nil { + require.Error(t, err) + for _, errMsg := range tt.wantErrs { + assert.ErrorContains(t, err, errMsg) + } + } + }) + } +} + +func Test_multiErrGroup_MultipleWaits(t *testing.T) { + var meg multiErrGroup + + meg.Go(func() error { + return errors.New("test error") + }) + + err1 := meg.Wait() + require.Error(t, err1) + + err2 := meg.Wait() + require.Error(t, err2) + assert.Equal(t, err1.Error(), err2.Error()) +} diff --git a/openfeature/multi/first_match_strategy.go b/openfeature/multi/first_match_strategy.go index 88546a5e..430de0df 100644 --- a/openfeature/multi/first_match_strategy.go +++ b/openfeature/multi/first_match_strategy.go @@ -8,11 +8,11 @@ import ( // newFirstMatchStrategy returns a [StrategyFn] that returns the result of the first [of.FeatureProvider] whose response is // not [of.FlagNotFoundCode]. This is executed sequentially, and not in parallel. -func newFirstMatchStrategy(providers []*NamedProvider) StrategyFn[FlagTypes] { +func newFirstMatchStrategy(providers []NamedProvider) StrategyFn[FlagTypes] { return firstMatchStrategyFn[FlagTypes](providers) } -func firstMatchStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[T] { +func firstMatchStrategyFn[T FlagTypes](providers []NamedProvider) StrategyFn[T] { return func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { for _, provider := range providers { resolution := Evaluate(ctx, provider, flag, defaultValue, flatCtx) @@ -30,7 +30,7 @@ func firstMatchStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[T] } // success! - resolution.FlagMetadata = setFlagMetadata(StrategyFirstMatch, provider.Name, resolution.FlagMetadata) + resolution.FlagMetadata = setFlagMetadata(StrategyFirstMatch, provider.Name(), resolution.FlagMetadata) return resolution } diff --git a/openfeature/multi/first_match_strategy_test.go b/openfeature/multi/first_match_strategy_test.go index 69c86828..614d0f19 100644 --- a/openfeature/multi/first_match_strategy_test.go +++ b/openfeature/multi/first_match_strategy_test.go @@ -29,10 +29,10 @@ func Test_FirstMatchStrategy_Evaluation(t *testing.T) { t.Run("Single Provider Match", func(t *testing.T) { mocks := createMockProviders(ctrl, 1) configureFirstMatchProviderMock(mocks[0], tt.successVal, TestErrorNone, "mock provider") - providers := make([]*NamedProvider, 0, 5) + providers := make([]NamedProvider, 0, 5) for i, m := range mocks { - providers = append(providers, &NamedProvider{ - Name: strconv.Itoa(i), + providers = append(providers, &namedProvider{ + name: strconv.Itoa(i), FeatureProvider: m, }) } @@ -40,16 +40,16 @@ func Test_FirstMatchStrategy_Evaluation(t *testing.T) { result := strategy(context.Background(), "test-string", tt.defaultVal, of.FlattenedContext{}) assert.Equal(t, tt.successVal, result.Value) assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) - assert.Equal(t, providers[0].Name, result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Equal(t, providers[0].Name(), result.FlagMetadata[MetadataSuccessfulProviderName]) }) t.Run("Default Resolution", func(t *testing.T) { mocks := createMockProviders(ctrl, 1) configureFirstMatchProviderMock(mocks[0], tt.defaultVal, TestErrorNotFound, "mock provider") - providers := make([]*NamedProvider, 0, 5) + providers := make([]NamedProvider, 0, 5) for i, m := range mocks { - providers = append(providers, &NamedProvider{ - Name: strconv.Itoa(i), + providers = append(providers, &namedProvider{ + name: strconv.Itoa(i), FeatureProvider: m, }) } @@ -66,10 +66,10 @@ func Test_FirstMatchStrategy_Evaluation(t *testing.T) { mocks := createMockProviders(ctrl, 5) configureFirstMatchProviderMock(mocks[0], tt.defaultVal, TestErrorNotFound, "mock provider 1") configureFirstMatchProviderMock(mocks[1], tt.successVal, TestErrorNone, "mock provider 2") - providers := make([]*NamedProvider, 0, 5) + providers := make([]NamedProvider, 0, 5) for i, m := range mocks { - providers = append(providers, &NamedProvider{ - Name: strconv.Itoa(i), + providers = append(providers, &namedProvider{ + name: strconv.Itoa(i), FeatureProvider: m, }) } @@ -78,16 +78,16 @@ func Test_FirstMatchStrategy_Evaluation(t *testing.T) { result := strategy(context.Background(), "test-flag", tt.defaultVal, of.FlattenedContext{}) assert.Equal(t, tt.successVal, result.Value) assert.Contains(t, result.FlagMetadata, MetadataSuccessfulProviderName) - assert.Equal(t, providers[1].Name, result.FlagMetadata[MetadataSuccessfulProviderName]) + assert.Equal(t, providers[1].Name(), result.FlagMetadata[MetadataSuccessfulProviderName]) }) t.Run("Evaluation stops after first error that is not a FLAG_NOT_FOUND error", func(t *testing.T) { mocks := createMockProviders(ctrl, 5) expectedErr := of.NewGeneralResolutionError("test error") - providers := make([]*NamedProvider, 0, 5) + providers := make([]NamedProvider, 0, 5) for i, m := range mocks { - providers = append(providers, &NamedProvider{ - Name: strconv.Itoa(i), + providers = append(providers, &namedProvider{ + name: strconv.Itoa(i), FeatureProvider: m, }) switch { diff --git a/openfeature/multi/first_success_strategy.go b/openfeature/multi/first_success_strategy.go index d30d73bb..fed09952 100644 --- a/openfeature/multi/first_success_strategy.go +++ b/openfeature/multi/first_success_strategy.go @@ -9,11 +9,11 @@ import ( // newFirstSuccessStrategy returns a [StrategyFn] that returns the result of the First [of.FeatureProvider] whose response // is not an error. This executed sequentially. -func newFirstSuccessStrategy(providers []*NamedProvider) StrategyFn[FlagTypes] { +func newFirstSuccessStrategy(providers []NamedProvider) StrategyFn[FlagTypes] { return firstSuccessStrategyFn[FlagTypes](providers) } -func firstSuccessStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[T] { +func firstSuccessStrategyFn[T FlagTypes](providers []NamedProvider) StrategyFn[T] { return func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { resolutionErrors := make([]error, 0, len(providers)) for _, provider := range providers { @@ -22,7 +22,7 @@ func firstSuccessStrategyFn[T FlagTypes](providers []*NamedProvider) StrategyFn[ resolutionErrors = append(resolutionErrors, resolution.Error()) continue } - resolution.FlagMetadata = setFlagMetadata(StrategyFirstSuccess, provider.Name, resolution.FlagMetadata) + resolution.FlagMetadata = setFlagMetadata(StrategyFirstSuccess, provider.Name(), resolution.FlagMetadata) return resolution } return BuildDefaultResult(StrategyFirstSuccess, defaultValue, errors.Join(resolutionErrors...)) diff --git a/openfeature/multi/first_success_strategy_test.go b/openfeature/multi/first_success_strategy_test.go index 0c178895..f6e3e2e3 100644 --- a/openfeature/multi/first_success_strategy_test.go +++ b/openfeature/multi/first_success_strategy_test.go @@ -93,9 +93,9 @@ func Test_FirstSuccessStrategyEvaluation(t *testing.T) { provider := of.NewMockFeatureProvider(ctrl) configureFirstSuccessProvider(provider, tt.successVal, true, TestErrorNone) - strategy := newFirstSuccessStrategy([]*NamedProvider{ - { - Name: "test-provider", + strategy := newFirstSuccessStrategy([]NamedProvider{ + &namedProvider{ + name: "test-provider", FeatureProvider: provider, }, }) @@ -114,13 +114,13 @@ func Test_FirstSuccessStrategyEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureFirstSuccessProvider(provider2, tt.defaultVal, false, TestErrorError) - strategy := newFirstSuccessStrategy([]*NamedProvider{ - { - Name: "success-provider", + strategy := newFirstSuccessStrategy([]NamedProvider{ + &namedProvider{ + name: "success-provider", FeatureProvider: provider1, }, - { - Name: "failure-provider", + &namedProvider{ + name: "failure-provider", FeatureProvider: provider2, }, }) @@ -139,13 +139,13 @@ func Test_FirstSuccessStrategyEvaluation(t *testing.T) { provider2 := of.NewMockFeatureProvider(ctrl) configureFirstSuccessProvider(provider2, tt.defaultVal, false, TestErrorError) - strategy := newFirstSuccessStrategy([]*NamedProvider{ - { - Name: "success-provider", + strategy := newFirstSuccessStrategy([]NamedProvider{ + &namedProvider{ + name: "success-provider", FeatureProvider: provider1, }, - { - Name: "failure-provider", + &namedProvider{ + name: "failure-provider", FeatureProvider: provider2, }, }) @@ -166,17 +166,17 @@ func Test_FirstSuccessStrategyEvaluation(t *testing.T) { provider3 := of.NewMockFeatureProvider(ctrl) configureFirstSuccessProvider(provider3, tt.defaultVal, false, TestErrorError) - strategy := newFirstSuccessStrategy([]*NamedProvider{ - { - Name: "provider1", + strategy := newFirstSuccessStrategy([]NamedProvider{ + &namedProvider{ + name: "provider1", FeatureProvider: provider1, }, - { - Name: "provider2", + &namedProvider{ + name: "provider2", FeatureProvider: provider2, }, - { - Name: "provider3", + &namedProvider{ + name: "provider3", FeatureProvider: provider3, }, }) diff --git a/openfeature/multi/isolation.go b/openfeature/multi/isolation.go index 268af842..12382506 100644 --- a/openfeature/multi/isolation.go +++ b/openfeature/multi/isolation.go @@ -18,6 +18,7 @@ type ( hooks []of.Hook capturedContext of.HookContext capturedHints of.HookHints + name string } // eventHandlingHookIsolator is equivalent to hookIsolator, but also implements [of.EventHandler] @@ -28,22 +29,24 @@ type ( // Compile-time interface compliance checks var ( + _ NamedProvider = (*hookIsolator)(nil) _ of.FeatureProvider = (*hookIsolator)(nil) _ of.Hook = (*hookIsolator)(nil) _ of.EventHandler = (*eventHandlingHookIsolator)(nil) ) // isolateProvider wraps a [of.FeatureProvider] to execute its hooks along with any additional ones. -func isolateProvider(provider of.FeatureProvider, extraHooks []of.Hook) *hookIsolator { +func isolateProvider(provider NamedProvider, extraHooks []of.Hook) *hookIsolator { return &hookIsolator{ FeatureProvider: provider, hooks: append(provider.Hooks(), extraHooks...), + name: provider.Name(), } } // isolateProviderWithEvents wraps a [of.FeatureProvider] to execute its hooks along with any additional ones. This is // identical to [isolateProvider], but also this will also implement [of.EventHandler]. -func isolateProviderWithEvents(provider of.FeatureProvider, extraHooks []of.Hook) *eventHandlingHookIsolator { +func isolateProviderWithEvents(provider NamedProvider, extraHooks []of.Hook) *eventHandlingHookIsolator { return &eventHandlingHookIsolator{*isolateProvider(provider, extraHooks)} } @@ -51,6 +54,14 @@ func (h *eventHandlingHookIsolator) EventChannel() <-chan of.Event { return h.FeatureProvider.(of.EventHandler).EventChannel() } +func (h *hookIsolator) Name() string { + return h.name +} + +func (h *hookIsolator) unwrap() of.FeatureProvider { + return h.FeatureProvider +} + func (h *hookIsolator) Before(_ context.Context, hookContext of.HookContext, hookHints of.HookHints) (*of.EvaluationContext, error) { // Used for capturing the context and hints h.mu.Lock() diff --git a/openfeature/multi/isolation_test.go b/openfeature/multi/isolation_test.go index b940e972..83317c11 100644 --- a/openfeature/multi/isolation_test.go +++ b/openfeature/multi/isolation_test.go @@ -24,7 +24,10 @@ func Test_HookIsolator_BeforeCapturesData(t *testing.T) { ctrl := gomock.NewController(t) provider := of.NewMockFeatureProvider(ctrl) provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - isolator := isolateProvider(provider, []of.Hook{}) + isolator := isolateProvider(&namedProvider{ + FeatureProvider: provider, + name: "test-provider", + }, []of.Hook{}) assert.Zero(t, isolator.capturedContext) assert.Zero(t, isolator.capturedHints) evalCtx, err := isolator.Before(context.Background(), hookCtx, hookHints) @@ -32,13 +35,17 @@ func Test_HookIsolator_BeforeCapturesData(t *testing.T) { assert.NotNil(t, evalCtx) assert.Equal(t, hookCtx, isolator.capturedContext) assert.Equal(t, hookHints, isolator.capturedHints) + assert.Equal(t, "test-provider", isolator.Name()) } func Test_HookIsolator_Hooks_ReturnsSelf(t *testing.T) { ctrl := gomock.NewController(t) provider := of.NewMockFeatureProvider(ctrl) provider.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - isolator := isolateProvider(provider, []of.Hook{}) + isolator := isolateProvider(&namedProvider{ + FeatureProvider: provider, + name: "test-provider", + }, []of.Hook{}) hooks := isolator.Hooks() assert.NotEmpty(t, hooks) assert.Same(t, isolator, hooks[0]) @@ -59,7 +66,10 @@ func Test_HookIsolator_ExecutesHooksDuringEvaluation_NoError(t *testing.T) { ProviderResolutionDetail: of.ProviderResolutionDetail{}, }) - isolator := isolateProvider(provider, nil) + isolator := isolateProvider(&namedProvider{ + FeatureProvider: provider, + name: "test-provider", + }, nil) result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) assert.True(t, result.Value) } @@ -75,7 +85,10 @@ func Test_HookIsolator_ExecutesHooksDuringEvaluation_BeforeErrorAbortsExecution( provider := of.NewMockFeatureProvider(ctrl) provider.EXPECT().Hooks().Return([]of.Hook{testHook}) - isolator := isolateProvider(provider, nil) + isolator := isolateProvider(&namedProvider{ + FeatureProvider: provider, + name: "test-provider", + }, nil) result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) assert.False(t, result.Value) } @@ -95,7 +108,10 @@ func Test_HookIsolator_ExecutesHooksDuringEvaluation_WithAfterError(t *testing.T ProviderResolutionDetail: of.ProviderResolutionDetail{}, }) - isolator := isolateProvider(provider, nil) + isolator := isolateProvider(&namedProvider{ + FeatureProvider: provider, + name: "test-provider", + }, nil) result := isolator.BooleanEvaluation(context.Background(), "test-flag", false, of.FlattenedContext{"targetingKey": "anon"}) assert.False(t, result.Value) } diff --git a/openfeature/multi/multiprovider.go b/openfeature/multi/multiprovider.go index cca5ffd2..8ad9a5e7 100644 --- a/openfeature/multi/multiprovider.go +++ b/openfeature/multi/multiprovider.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "slices" "strings" "sync" @@ -29,13 +30,10 @@ const ( ) type ( - // ProviderMap is an alias for a map containing unique names for each included [of.FeatureProvider] - ProviderMap = map[string]of.FeatureProvider - // Provider is an implementation of [of.FeatureProvider] that can execute multiple providers using various // strategies. Provider struct { - providers ProviderMap + providers []NamedProvider metadata of.Metadata initialized bool overallStatus of.State @@ -46,17 +44,24 @@ type ( strategyFunc StrategyFn[FlagTypes] // used for evaluating strategies logger *slog.Logger outboundEvents chan of.Event - inboundEvents chan namedEvent workerGroup sync.WaitGroup shutdownFunc context.CancelFunc globalHooks []of.Hook } - // NamedProvider allows for a unique name to be assigned to a provider during a multi-provider set up. + // NamedProvider extends [of.FeatureProvider] by adding a unique provider name. + NamedProvider interface { + of.FeatureProvider + // Name returns the unique name assigned to the provider. + Name() string + } + + // namedProvider allows for a unique name to be assigned to a provider during a multi-provider set up. // The name will be used when reporting errors & results to specify the provider associated with them. - NamedProvider struct { - Name string + namedProvider struct { of.FeatureProvider + name string + extraHooks []of.Hook } // Option function used for setting configuration via the options pattern @@ -75,20 +80,38 @@ type ( customStrategy StrategyConstructor logger *slog.Logger hooks []of.Hook - providerHooks map[string][]of.Hook + providers []*namedProvider customComparator Comparator } + + // namedEventHandler is a wrapper around an [of.EventHandler] that includes the provider name. + namedEventHandler struct { + of.EventHandler + name string + } ) +// Name returns the unique name assigned to the provider. +func (n *namedProvider) Name() string { + return n.name +} + +// unwrap returns the underlying [of.FeatureProvider] instance wrapped by this [namedProvider]. +func (n *namedProvider) unwrap() of.FeatureProvider { + return n.FeatureProvider +} + var ( stateValues map[of.State]int stateTable [3]of.State eventTypeToState map[of.EventType]of.State // Compile-time interface compliance checks - _ of.FeatureProvider = (*Provider)(nil) - _ of.EventHandler = (*Provider)(nil) - _ of.StateHandler = (*Provider)(nil) + _ of.FeatureProvider = (*Provider)(nil) + _ of.EventHandler = (*Provider)(nil) + _ of.ContextAwareStateHandler = (*Provider)(nil) + _ of.Tracker = (*Provider)(nil) + _ NamedProvider = (*namedProvider)(nil) ) // init Initialize "constants" used for event handling priorities and filtering. @@ -150,47 +173,36 @@ func WithCustomStrategy(s StrategyConstructor) Option { } // WithGlobalHooks sets the global hooks for the provider. These are [of.Hook] instances that affect ALL [of.FeatureProvider] -// instances. For hooks that target specific providers make sure to attach them to that provider directly, or use the -// [WithProviderHooks] [Option] if that provider does not provide its own hook functionality. +// instances. To apply hooks to specific providers, attach them directly to that provider, or include them in the [WithProvider] [Option] +// if the provider does not support its own hook functionality. func WithGlobalHooks(hooks ...of.Hook) Option { return func(conf *configuration) { conf.hooks = hooks } } -// WithProviderHooks sets [of.Hook] instances that execute only for a specific [of.FeatureProvider]. The providerName -// must match the unique provider name set during [multi.Provider] creation. This should only be used if you need hooks -// that execute around a specific provider, but that provider does not currently accept a way to set hooks. This [Option] -// can be used multiple times using unique provider names. Using a provider name that is not known will cause an error. -func WithProviderHooks(providerName string, hooks ...of.Hook) Option { +// WithProvider registers a specific [of.FeatureProvider] instance under the given providerName. The providerName +// must be unique and correspond to the name used when creating the [Provider]. Optional [of.Hook] instances +// may also be provided, which will execute only for this specific provider. This [Option] can be used multiple times +// with unique provider names to register multiple providers. The order in which options +// are provided determines the order in which the providers are registered and evaluated. +func WithProvider(providerName string, provider of.FeatureProvider, hooks ...of.Hook) Option { return func(conf *configuration) { - conf.providerHooks[providerName] = hooks + conf.providers = append(conf.providers, &namedProvider{ + name: providerName, + FeatureProvider: provider, + extraHooks: hooks, + }) } } // Multiprovider Implementation - -// toNamedProviderSlice converts the provided [ProviderMap] into a slice of [NamedProvider] instances -func toNamedProviderSlice(m ProviderMap) []*NamedProvider { - s := make([]*NamedProvider, 0, len(m)) - for name, provider := range m { - s = append(s, &NamedProvider{Name: name, FeatureProvider: provider}) - } - - return s -} - -func buildMetadata(m ProviderMap) of.Metadata { +func buildMetadata(m []NamedProvider) of.Metadata { var separator string var metaName strings.Builder metaName.WriteString("MultiProvider {") - names := make([]string, 0, len(m)) - for n := range m { - names = append(names, n) - } - slices.Sort(names) - for _, name := range names { - metaName.WriteString(fmt.Sprintf("%s%s: %s", separator, name, m[name].Metadata().Name)) + for _, p := range m { + metaName.WriteString(fmt.Sprintf("%s%s: %s", separator, p.Name(), p.Metadata().Name)) if separator == "" { separator = ", " } @@ -203,53 +215,53 @@ func buildMetadata(m ProviderMap) of.Metadata { } // NewProvider returns a new [multi.Provider] that acts as a unified interface of multiple providers for interaction. -func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy, options ...Option) (*Provider, error) { - if len(providerMap) == 0 { - return nil, errors.New("providerMap cannot be nil or empty") - } - +func NewProvider(evaluationStrategy EvaluationStrategy, options ...Option) (*Provider, error) { config := &configuration{ - logger: slog.New(slog.DiscardHandler), - providerHooks: make(map[string][]of.Hook), + logger: slog.New(slog.DiscardHandler), + providers: make([]*namedProvider, 0, 2), } for _, opt := range options { opt(config) } - providers := providerMap - collectedHooks := make([]of.Hook, 0, len(providerMap)) - for name, provider := range providerMap { + if len(config.providers) == 0 { + return nil, errors.New("no providers configured: at least one provider must be registered using WithProvider()") + } + + providers := make([]NamedProvider, 0, len(config.providers)) + collectedHooks := make([]of.Hook, 0, len(config.providers)) + for i, provider := range config.providers { // Validate Providers - if name == "" { - return nil, errors.New("provider name cannot be the empty string") + if provider.FeatureProvider == nil { + return nil, fmt.Errorf("provider %s at %d cannot be nil", provider.name, i) } - - if provider == nil { - return nil, fmt.Errorf("provider %s cannot be nil", name) + if provider.name == "" { + return nil, fmt.Errorf("provider name at %d cannot be the empty string", i) } // Wrap any providers that include hooks - if (len(provider.Hooks()) + len(config.providerHooks[name])) == 0 { + if (len(provider.Hooks()) + len(provider.extraHooks)) == 0 { + providers = append(providers, provider) continue } - var wrappedProvider of.FeatureProvider - if _, ok := provider.(of.EventHandler); ok { - wrappedProvider = isolateProviderWithEvents(provider, config.providerHooks[name]) + var wrappedProvider NamedProvider + if _, ok := provider.FeatureProvider.(of.EventHandler); ok { + wrappedProvider = isolateProviderWithEvents(provider, provider.extraHooks) } else { - wrappedProvider = isolateProvider(provider, config.providerHooks[name]) + wrappedProvider = isolateProvider(provider, provider.extraHooks) } - providers[name] = wrappedProvider - collectedHooks = append(collectedHooks, wrappedProvider.Hooks()...) + providers = append(providers, wrappedProvider) + collectedHooks = slices.Concat(collectedHooks, wrappedProvider.Hooks()) } multiProvider := &Provider{ providers: providers, outboundEvents: make(chan of.Event, len(providers)), logger: config.logger, - metadata: buildMetadata(providerMap), + metadata: buildMetadata(providers), overallStatus: of.NotReadyState, providerStatus: make(map[string]of.State, len(providers)), globalHooks: append(config.hooks, collectedHooks...), @@ -276,12 +288,7 @@ func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy, } // Providers returns slice of providers wrapped in [NamedProvider] structs. -func (p *Provider) Providers() []*NamedProvider { - return toNamedProviderSlice(p.providers) -} - -// ProvidersByName Returns the internal [ProviderMap]. -func (p *Provider) ProvidersByName() ProviderMap { +func (p *Provider) Providers() []NamedProvider { return p.providers } @@ -350,106 +357,137 @@ func (p *Provider) ObjectEvaluation(ctx context.Context, flag string, defaultVal // Init will run the initialize method for all internal [of.FeatureProvider] instances and aggregate any errors. func (p *Provider) Init(evalCtx of.EvaluationContext) error { - var eg errgroup.Group + return p.InitWithContext(context.Background(), evalCtx) +} + +// InitWithContext will run the initialize method for all internal [of.FeatureProvider] instances and aggregate any errors. +func (p *Provider) InitWithContext(ctx context.Context, evalCtx of.EvaluationContext) error { + eg, ctx := errgroup.WithContext(ctx) // wrapper type used only for initialization of event listener workers - type namedEventHandler struct { - of.EventHandler - name string - } - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "start initialization") - p.inboundEvents = make(chan namedEvent, len(p.providers)) + p.logger.LogAttrs(ctx, slog.LevelDebug, "start initialization") handlers := make(chan namedEventHandler, len(p.providers)) - for name, provider := range p.providers { + for _, provider := range p.providers { + name := provider.Name() // Initialize each provider to not ready state. No locks required there are no workers running p.updateProviderState(name, of.NotReadyState) l := p.logger.With(slog.String(MetadataProviderName, name)) prov := provider eg.Go(func() error { - l.LogAttrs(context.Background(), slog.LevelDebug, "starting initialization") - stateHandle, ok := prov.(of.StateHandler) - if !ok { - l.LogAttrs(context.Background(), slog.LevelDebug, "StateHandle not implemented, skipping initialization") - } else if err := stateHandle.Init(evalCtx); err != nil { - l.LogAttrs(context.Background(), slog.LevelError, "initialization failed", slog.Any("error", err)) - return &ProviderError{ - Err: err, - ProviderName: name, + l.LogAttrs(ctx, slog.LevelDebug, "starting initialization") + if stateHandle, ok := tryAs[of.StateHandler](prov); ok { + var err error + if contextAwareHandle, ok := stateHandle.(of.ContextAwareStateHandler); ok { + err = contextAwareHandle.InitWithContext(ctx, evalCtx) + } else { + err = stateHandle.Init(evalCtx) } + + if err != nil { + l.LogAttrs(ctx, slog.LevelError, "initialization failed", slog.Any("error", err)) + p.updateProviderState(name, of.ErrorState) + return &ProviderError{ + err: err, + ProviderName: name, + } + } + } else { + l.LogAttrs(ctx, slog.LevelDebug, "StateHandle not implemented, skipping initialization") } - l.LogAttrs(context.Background(), slog.LevelDebug, "initialization successful") - if eventer, ok := provider.(of.EventHandler); ok { - l.LogAttrs(context.Background(), slog.LevelDebug, "detected EventHandler implementation") + l.LogAttrs(ctx, slog.LevelDebug, "initialization successful") + if eventer, ok := tryAs[of.EventHandler](prov); ok { + l.LogAttrs(ctx, slog.LevelDebug, "detected EventHandler implementation") handlers <- namedEventHandler{eventer, name} - } else { - // Do not yet update providers that need event handling - p.updateProviderState(name, of.ReadyState) } + p.updateProviderState(name, of.ReadyState) return nil }) } if err := eg.Wait(); err != nil { var pErr *ProviderError - if errors.As(err, &pErr) { - // Update provider status to error, no event needs to be emitted yet - p.updateProviderState(pErr.ProviderName, of.ErrorState) - } else { + if !errors.As(err, &pErr) { pErr = &ProviderError{ - Err: err, + err: err, ProviderName: "unknown", } - p.setStatus(of.ErrorState) } - return err + p.setStatus(of.ErrorState) + return pErr } close(handlers) workerCtx, shutdownFunc := context.WithCancel(context.Background()) - for h := range handlers { - go p.startListening(workerCtx, h.name, h.EventHandler, &p.workerGroup) - } p.shutdownFunc = shutdownFunc - p.workerGroup.Add(1) - go func() { - workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker")) - defer p.workerGroup.Done() - for e := range p.inboundEvents { - l := workerLogger.With( - slog.String(MetadataProviderName, e.providerName), - slog.String(MetadataProviderType, e.ProviderName), - ) - l.LogAttrs(context.Background(), slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType))) - if p.updateProviderStateFromEvent(e) { - p.outboundEvents <- e.Event - l.LogAttrs(context.Background(), slog.LevelDebug, "forwarded state update event") - } else { - l.LogAttrs(context.Background(), slog.LevelDebug, "total state not updated, inbound event will not be emitted") - } - } - }() + if len(handlers) > 0 { + p.workerGroup.Add(1) + go p.forwardProviderEvents(workerCtx, handlers) + } else { + // we don't emit any events so we can just close the channel + close(p.outboundEvents) + } p.setStatus(of.ReadyState) p.initialized = true return nil } -// startListening is intended to be called on a per-provider basis as a goroutine to listen to events from a provider -// implementing [of.EventHandler]. -func (p *Provider) startListening(ctx context.Context, name string, h of.EventHandler, wg *sync.WaitGroup) { - wg.Add(1) - defer wg.Done() - for { - select { - case e := <-h.EventChannel(): - e.EventMetadata[MetadataProviderName] = name - e.EventMetadata[MetadataProviderType] = h.(of.FeatureProvider).Metadata().Name - p.inboundEvents <- namedEvent{ - Event: e, - providerName: name, +// forwardProviderEvents establishes an event forwarding pipeline that collects events from multiple provider +// event handlers and forwards them to the multiprovider's outbound event channel. It spawns a goroutine for +// each provider handler to listen for events, aggregates them through an internal pipe, and selectively forwards +// events that result in state changes. The function blocks until workerCtx is cancelled or all provider event +// channels are closed, ensuring proper cleanup by closing the outbound channel when complete. +func (p *Provider) forwardProviderEvents(workerCtx context.Context, handlers chan namedEventHandler) { + defer p.workerGroup.Done() + defer close(p.outboundEvents) + + workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker")) + pipe := make(chan namedEvent) + var wg sync.WaitGroup + for ch := range handlers { + wg.Add(1) + go func(ctx context.Context, h of.EventHandler, name string, out chan<- namedEvent) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case e, ok := <-h.EventChannel(): + if !ok { + return + } + if e.EventMetadata == nil { + e.EventMetadata = make(map[string]any) + } + e.EventMetadata[MetadataProviderName] = name + if p, ok := h.(of.FeatureProvider); ok { + e.EventMetadata[MetadataProviderType] = p.Metadata().Name + } + out <- namedEvent{ + Event: e, + providerName: name, + } + } } - case <-ctx.Done(): - return + }(workerCtx, ch.EventHandler, ch.name, pipe) + } + + go func() { + wg.Wait() + close(pipe) + }() + + for e := range pipe { + l := workerLogger.With( + slog.String(MetadataProviderName, e.providerName), + slog.String(MetadataProviderType, e.ProviderName), + ) + l.LogAttrs(workerCtx, slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType))) + if p.updateProviderStateFromEvent(e) { + p.outboundEvents <- e.Event + l.LogAttrs(workerCtx, slog.LevelDebug, "forwarded state update event") + } else { + l.LogAttrs(workerCtx, slog.LevelDebug, "total state not updated, inbound event will not be emitted") } } } @@ -475,7 +513,10 @@ func (p *Provider) updateProviderStateFromEvent(e namedEvent) bool { if e.EventType == of.ProviderConfigChange { p.logger.LogAttrs(context.Background(), slog.LevelDebug, "ProviderConfigChange event", slog.String("event-message", e.Message)) } - logProviderState(p.logger, e, p.providerStatus[e.providerName]) + p.providerStatusLock.Lock() + previousState := p.providerStatus[e.providerName] + p.providerStatusLock.Unlock() + logProviderState(p.logger, e, previousState) return p.updateProviderState(e.providerName, eventTypeToState[e.EventType]) } @@ -512,40 +553,54 @@ func logProviderState(l *slog.Logger, e namedEvent, previousState of.State) { // Shutdown Shuts down all internal [of.FeatureProvider] instances and internal event listeners func (p *Provider) Shutdown() { + ctx := context.Background() + err := p.ShutdownWithContext(ctx) + if err != nil { + p.logger.LogAttrs(ctx, slog.LevelWarn, "error during shutdown", slog.Any("error", err)) + } +} + +// ShutdownWithContext shuts down all internal [of.FeatureProvider] instances and internal event listeners +func (p *Provider) ShutdownWithContext(ctx context.Context) error { if !p.initialized { // Don't do anything if we were never initialized - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider not initialized, skipping shutdown") - return + p.logger.LogAttrs(ctx, slog.LevelDebug, "provider not initialized, skipping shutdown") + return nil } + + p.logger.LogAttrs(ctx, slog.LevelDebug, "starting provider shutdown") // Stop all event listener workers, shutdown events should not affect overall state p.shutdownFunc() - // Stop forwarding worker - close(p.inboundEvents) - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "triggered worker shutdown") - // Wait for workers to stop - p.workerGroup.Wait() - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "worker shutdown completed") - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "starting provider shutdown") - var wg sync.WaitGroup - for _, provider := range p.providers { - wg.Add(1) + meg := multiErrGroup{} - go func(p of.FeatureProvider) { - defer wg.Done() - if stateHandle, ok := p.(of.StateHandler); ok { - stateHandle.Shutdown() - } - }(provider) + for _, provider := range p.providers { + name := provider.Name() + if stateHandle, ok := tryAs[of.StateHandler](provider); ok { + meg.Go(func() error { + if contextAwareHandle, ok := stateHandle.(of.ContextAwareStateHandler); ok { + if err := contextAwareHandle.ShutdownWithContext(ctx); err != nil { + return &ProviderError{ProviderName: name, err: err} + } + } else { + stateHandle.Shutdown() + } + return nil + }) + } } - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "waiting for provider shutdown completion") - wg.Wait() - p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider shutdown completed") + p.logger.LogAttrs(ctx, slog.LevelDebug, "waiting for provider shutdown completion") + errs := meg.Wait() + // Stop forwarding worker + p.logger.LogAttrs(ctx, slog.LevelDebug, "triggered worker shutdown") + // Wait for workers to stop + p.workerGroup.Wait() + p.logger.LogAttrs(ctx, slog.LevelDebug, "worker shutdown completed") p.setStatus(of.NotReadyState) - close(p.outboundEvents) - p.outboundEvents = nil - p.inboundEvents = nil p.initialized = false + p.logger.LogAttrs(ctx, slog.LevelDebug, "provider shutdown completed") + + return errs } // Status provides the current state of the [multi.Provider]. @@ -566,3 +621,47 @@ func (p *Provider) setStatus(state of.State) { func (p *Provider) EventChannel() <-chan of.Event { return p.outboundEvents } + +// Track implements the [of.Tracker] interface by forwarding tracking calls to all internal providers that +// are in ready state and implement the [of.Tracker] interface. +func (p *Provider) Track(ctx context.Context, trackingEventName string, evaluationContext of.EvaluationContext, details of.TrackingEventDetails) { + if !p.initialized { + // Don't do anything if we were never initialized + p.logger.LogAttrs(ctx, slog.LevelDebug, "provider not initialized, skipping tracking", slog.String("tracking-event", trackingEventName)) + return + } + p.providerStatusLock.Lock() + statuses := maps.Clone(p.providerStatus) + p.providerStatusLock.Unlock() + providers := make([]NamedProvider, 0, len(p.providers)) + for _, p := range p.providers { + if statuses[p.Name()] == of.ReadyState { + providers = append(providers, p) + } + } + for _, provider := range providers { + if tracker, ok := tryAs[of.Tracker](provider); ok { + tracker.Track(ctx, trackingEventName, evaluationContext, details) + } + } +} + +// tryAs attempts to extract and type-assert the underlying [of.FeatureProvider] from a [NamedProvider]. +// It first checks if the provider implements an unwrap() method to access the wrapped provider, +// then attempts to cast that provider to type T. Returns the casted value and true if successful, +// or the zero value of T and false if the provider doesn't support unwrapping or doesn't implement type T. +// This is used internally to check if wrapped providers implement optional interfaces like +// [of.StateHandler], [of.EventHandler], or [of.Tracker]. +func tryAs[T any](p NamedProvider) (T, bool) { + var v T + + unwrapped, ok := p.(interface { + unwrap() of.FeatureProvider + }) + if !ok { + return v, false + } + + v, ok = unwrapped.unwrap().(T) + return v, ok +} diff --git a/openfeature/multi/multiprovider_test.go b/openfeature/multi/multiprovider_test.go index 042688c4..ef258d0c 100644 --- a/openfeature/multi/multiprovider_test.go +++ b/openfeature/multi/multiprovider_test.go @@ -3,7 +3,6 @@ package multi import ( "context" "errors" - "regexp" "testing" "time" @@ -18,99 +17,67 @@ func TestMultiProvider_ProvidersMethod(t *testing.T) { testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - - mp, err := NewProvider(providers, StrategyFirstSuccess) + mp, err := NewProvider(StrategyFirstSuccess, WithProvider("provider1", testProvider1), WithProvider("provider2", testProvider2)) require.NoError(t, err) p := mp.Providers() assert.Len(t, p, 2) - assert.Regexp(t, regexp.MustCompile("provider[1-2]"), p[0].Name) - assert.NotNil(t, p[0].FeatureProvider) - assert.Regexp(t, regexp.MustCompile("provider[1-2]"), p[1].Name) - assert.NotNil(t, p[1].FeatureProvider) + assert.NotNil(t, p[0]) + assert.Implements(t, (*of.FeatureProvider)(nil), p[0]) + assert.Equal(t, "provider1", p[0].Name()) + assert.NotNil(t, p[1]) + assert.Implements(t, (*of.FeatureProvider)(nil), p[1]) + assert.Equal(t, "provider2", p[1].Name()) } func TestMultiProvider_NewMultiProvider(t *testing.T) { t.Run("nil providerMap returns an error", func(t *testing.T) { - _, err := NewProvider(nil, StrategyFirstMatch) + _, err := NewProvider(StrategyFirstMatch) require.Errorf(t, err, "providerMap cannot be nil or empty") }) t.Run("naming a provider the empty string returns an error", func(t *testing.T) { - providers := make(ProviderMap) - providers[""] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - _, err := NewProvider(providers, StrategyFirstMatch) + _, err := NewProvider(StrategyFirstMatch, WithProvider("", imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}))) require.Errorf(t, err, "provider name cannot be the empty string") }) t.Run("nil provider within map returns an error", func(t *testing.T) { - providers := make(ProviderMap) - providers["provider1"] = nil - _, err := NewProvider(providers, StrategyFirstMatch) + _, err := NewProvider(StrategyFirstMatch, WithProvider("provider1", nil)) require.Errorf(t, err, "provider provider1 cannot be nil") }) t.Run("unknown evaluation strategyFunc returns an error", func(t *testing.T) { - providers := make(ProviderMap) - providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - _, err := NewProvider(providers, "unknown") + _, err := NewProvider("unknown", WithProvider("provider1", imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}))) require.Errorf(t, err, "unknown is an unknown evaluation strategyFunc") }) t.Run("setting custom strategyFunc without custom strategyFunc option returns error", func(t *testing.T) { - providers := make(ProviderMap) - providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - _, err := NewProvider(providers, StrategyCustom) + _, err := NewProvider(StrategyCustom, WithProvider("provider1", imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}))) require.Errorf(t, err, "A custom strategyFunc must be set via an option if StrategyCustom is set") }) t.Run("success", func(t *testing.T) { - providers := make(ProviderMap) - providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - mp, err := NewProvider(providers, StrategyComparison) + mp, err := NewProvider(StrategyComparison, WithProvider("provider1", imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}))) require.NoError(t, err) assert.NotZero(t, mp) }) t.Run("success with custom provider", func(t *testing.T) { - providers := make(ProviderMap) - providers["provider1"] = imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - mp, err := NewProvider(providers, StrategyCustom, WithCustomStrategy(func(providers []*NamedProvider) StrategyFn[FlagTypes] { + mp, err := NewProvider(StrategyCustom, WithCustomStrategy(func(providers []NamedProvider) StrategyFn[FlagTypes] { return func(ctx context.Context, flag string, defaultValue FlagTypes, evalCtx of.FlattenedContext) of.GenericResolutionDetail[FlagTypes] { return of.GenericResolutionDetail[FlagTypes]{ Value: defaultValue, ProviderResolutionDetail: of.ProviderResolutionDetail{Reason: of.UnknownReason}, } } - })) + }), + WithProvider("provider1", imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{})), + ) require.NoError(t, err) assert.NotZero(t, mp) }) } -func TestMultiProvider_ProvidersByNamesMethod(t *testing.T) { - testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - - mp, err := NewProvider(providers, StrategyFirstMatch) - require.NoError(t, err) - - p := mp.ProvidersByName() - - assert.Len(t, p, 2) - require.Contains(t, p, "provider1") - assert.Equal(t, p["provider1"], testProvider1) - require.Contains(t, p, "provider2") - assert.Equal(t, p["provider2"], testProvider2) -} - func TestMultiProvider_MetaData(t *testing.T) { t.Run("two providers", func(t *testing.T) { testProvider1 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) @@ -121,11 +88,11 @@ func TestMultiProvider_MetaData(t *testing.T) { }) testProvider2.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - - mp, err := NewProvider(providers, StrategyFirstSuccess) + mp, err := NewProvider( + StrategyFirstSuccess, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + ) require.NoError(t, err) metadata := mp.Metadata() @@ -147,12 +114,12 @@ func TestMultiProvider_MetaData(t *testing.T) { }) testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - providers["provider3"] = testProvider3 - - mp, err := NewProvider(providers, StrategyFirstSuccess) + mp, err := NewProvider( + StrategyFirstSuccess, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + WithProvider("provider3", testProvider3), + ) require.NoError(t, err) metadata := mp.Metadata() @@ -187,12 +154,12 @@ func TestMultiProvider_Init(t *testing.T) { testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - providers["provider3"] = testProvider3 - - mp, err := NewProvider(providers, StrategyFirstMatch) + mp, err := NewProvider( + StrategyFirstMatch, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + WithProvider("provider3", testProvider3), + ) require.NoError(t, err) t.Cleanup(func() { @@ -228,12 +195,12 @@ func TestMultiProvider_InitErrorWithProvider(t *testing.T) { testProvider1.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) testProvider2 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - providers["provider3"] = testProvider3 - - mp, err := NewProvider(providers, StrategyFirstMatch) + mp, err := NewProvider( + StrategyFirstMatch, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + WithProvider("provider3", testProvider3), + ) require.NoError(t, err) attributes := map[string]any{ @@ -256,11 +223,12 @@ func TestMultiProvider_Shutdown_WithoutInit(t *testing.T) { testProvider3.EXPECT().Metadata().Return(of.Metadata{Name: "MockProvider"}) testProvider3.EXPECT().Hooks().Return([]of.Hook{}).MinTimes(1) - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - providers["provider3"] = testProvider3 - mp, err := NewProvider(providers, StrategyFirstMatch) + mp, err := NewProvider( + StrategyFirstMatch, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + WithProvider("provider3", testProvider3), + ) require.NoError(t, err) mp.Shutdown() @@ -287,11 +255,12 @@ func TestMultiProvider_Shutdown_WithInit(t *testing.T) { handledHandler, } - providers := make(ProviderMap) - providers["provider1"] = testProvider1 - providers["provider2"] = testProvider2 - providers["provider3"] = testProvider3 - mp, err := NewProvider(providers, StrategyFirstMatch) + mp, err := NewProvider( + StrategyFirstMatch, + WithProvider("provider1", testProvider1), + WithProvider("provider2", testProvider2), + WithProvider("provider3", testProvider3), + ) require.NoError(t, err) evalCtx := of.NewTargetlessEvaluationContext(map[string]any{ "foo": "bar", @@ -354,34 +323,40 @@ func TestMultiProvider_StateUpdateWithSameTypeProviders(t *testing.T) { primaryProvider := newMockProviderWithEvents(ctrl, "MockProvider") secondaryProvider := newMockProviderWithEvents(ctrl, "MockProvider") - providers := ProviderMap{ - "primary": primaryProvider, - "secondary": secondaryProvider, - } - - multiProvider, err := NewProvider(providers, StrategyFirstMatch) + mp, err := NewProvider( + StrategyFirstMatch, + WithProvider("primary", primaryProvider), + WithProvider("secondary", secondaryProvider), + ) if err != nil { t.Fatalf("failed to create multi-provider: %v", err) } - t.Cleanup(multiProvider.Shutdown) + t.Cleanup(mp.Shutdown) // Initialize the provider ctx := of.NewEvaluationContext("test", nil) - if err := multiProvider.Init(ctx); err != nil { + if err := mp.Init(ctx); err != nil { t.Fatalf("failed to initialize multi-provider: %v", err) } primaryProvider.EmitEvent(of.ProviderError, "fail to fetch data") secondaryProvider.EmitEvent(of.ProviderReady, "rev 1") - - time.Sleep(200 * time.Millisecond) + // wait for processing + require.Eventually(t, func() bool { + select { + case <-mp.outboundEvents: + return true + default: + return false + } + }, time.Second, 50*time.Millisecond, "expected event was not emitted within timeout") // Check the state after the error event - multiProvider.providerStatusLock.Lock() - primaryState := multiProvider.providerStatus["primary"] - secondaryState := multiProvider.providerStatus["secondary"] - numProviders := len(multiProvider.providerStatus) - multiProvider.providerStatusLock.Unlock() + mp.providerStatusLock.Lock() + primaryState := mp.providerStatus["primary"] + secondaryState := mp.providerStatus["secondary"] + numProviders := len(mp.providerStatus) + mp.providerStatusLock.Unlock() if primaryState != of.ErrorState { t.Errorf("Expected primary-mock state to be ERROR after emitting error event, got %s", primaryState) @@ -396,12 +371,140 @@ func TestMultiProvider_StateUpdateWithSameTypeProviders(t *testing.T) { } } +func TestMultiProvider_Track(t *testing.T) { + t.Run("forwards tracking to all ready providers that implement Tracker", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + provider1 := newMockProviderWithEvents(ctrl, "provider1") + provider2 := newMockProviderWithEvents(ctrl, "provider2") + provider3 := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) // Does not implement Tracker + + mp, err := NewProvider( + StrategyFirstSuccess, + WithProvider("provider1", provider1), + WithProvider("provider2", provider2), + WithProvider("provider3", provider3), + ) + require.NoError(t, err) + t.Cleanup(mp.Shutdown) + + evalCtx := of.NewEvaluationContext("user-123", map[string]any{"plan": "premium"}) + err = mp.Init(evalCtx) + require.NoError(t, err) + + trackingEventName := "button-clicked" + details := of.NewTrackingEventDetails(42.0).Add("currency", "USD") + + ctx := t.Context() + // Expect Track to be called on providers that implement Tracker + provider1.MockTracker.EXPECT().Track(ctx, trackingEventName, evalCtx, details).Times(1) + provider2.MockTracker.EXPECT().Track(ctx, trackingEventName, evalCtx, details).Times(1) + + mp.Track(ctx, trackingEventName, evalCtx, details) + }) + + t.Run("does not track when provider is not initialized", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + provider1 := newMockProviderWithEvents(ctrl, "provider1") + // manual shutdown on cleanup because multi-provider won't be initialized + t.Cleanup(provider1.Shutdown) + + mp, err := NewProvider(StrategyFirstSuccess, WithProvider("provider1", provider1)) + require.NoError(t, err) + t.Cleanup(mp.Shutdown) + + // Don't initialize the multi-provider + ctx := context.Background() + trackingEventName := "button-clicked" + evalCtx := of.NewEvaluationContext("user-123", map[string]any{}) + details := of.TrackingEventDetails{} + + // Should not call Track on provider + provider1.MockTracker.EXPECT().Track(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + mp.Track(ctx, trackingEventName, evalCtx, details) + }) + + t.Run("only tracks on providers in ready state", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + readyProvider := newMockProviderWithEvents(ctrl, "ready-provider") + errorProvider := newMockProviderWithEvents(ctrl, "error-provider") + + mp, err := NewProvider( + StrategyFirstSuccess, + WithProvider("ready-provider", readyProvider), + WithProvider("error-provider", errorProvider), + ) + require.NoError(t, err) + t.Cleanup(mp.Shutdown) + + evalCtx := of.NewEvaluationContext("user-456", map[string]any{}) + err = mp.Init(evalCtx) + require.NoError(t, err) + + // Simulate error state for one provider + errorProvider.EmitEvent(of.ProviderError, "error") + + // wait for event processing + require.Eventually(t, func() bool { + select { + case e := <-mp.outboundEvents: + return e.ProviderName == "error-provider" && e.EventType == of.ProviderError + default: + return false + } + }, time.Second, 50*time.Millisecond, "expected event was not emitted within timeout") + + trackingEventName := "page-view" + details := of.TrackingEventDetails{} + + ctx := t.Context() + readyProvider.MockTracker.EXPECT().Track(ctx, trackingEventName, evalCtx, details).Times(1) + errorProvider.MockTracker.EXPECT().Track(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + mp.Track(ctx, trackingEventName, evalCtx, details) + }) + + t.Run("handles providers that don't implement Tracker", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + trackerProvider := newMockProviderWithEvents(ctrl, "tracker-provider") + nonTrackerProvider := imp.NewInMemoryProvider(map[string]imp.InMemoryFlag{}) + + mp, err := NewProvider( + StrategyFirstSuccess, + WithProvider("tracker-provider", trackerProvider), + WithProvider("non-tracker", nonTrackerProvider), + ) + require.NoError(t, err) + t.Cleanup(mp.Shutdown) + + evalCtx := of.NewEvaluationContext("user-789", map[string]any{}) + err = mp.Init(evalCtx) + require.NoError(t, err) + + trackingEventName := "conversion" + details := of.NewTrackingEventDetails(99.99) + + ctx := t.Context() + trackerProvider.MockTracker.EXPECT().Track(ctx, trackingEventName, evalCtx, details).Times(1) + mp.Track(ctx, trackingEventName, evalCtx, details) + }) +} + var _ of.StateHandler = (*mockProviderWithEvents)(nil) -// mockProviderWithEvents wraps a mock provider to add EventHandler capability +// mockProviderWithEvents wraps a mock provider to add EventHandler and optional Tracker capability type mockProviderWithEvents struct { *of.MockFeatureProvider *of.MockStateHandler + *of.MockTracker eventChannel chan of.Event metadata of.Metadata } @@ -409,6 +512,7 @@ type mockProviderWithEvents struct { func newMockProviderWithEvents(ctrl *gomock.Controller, name string) *mockProviderWithEvents { mockProvider := of.NewMockFeatureProvider(ctrl) mockStateHandler := of.NewMockStateHandler(ctrl) + mockTracker := of.NewMockTracker(ctrl) eventChan := make(chan of.Event, 10) metadata := of.Metadata{Name: name} @@ -434,6 +538,7 @@ func newMockProviderWithEvents(ctrl *gomock.Controller, name string) *mockProvid MockStateHandler: mockStateHandler, eventChannel: eventChan, metadata: metadata, + MockTracker: mockTracker, } } @@ -460,3 +565,9 @@ func (m *mockProviderWithEvents) EmitEvent(eventType of.EventType, message strin }, } } + +func (m *mockProviderWithEvents) Track(ctx context.Context, trackingEventName string, evaluationContext of.EvaluationContext, details of.TrackingEventDetails) { + if m.MockTracker != nil { + m.MockTracker.Track(ctx, trackingEventName, evaluationContext, details) + } +} diff --git a/openfeature/multi/strategies.go b/openfeature/multi/strategies.go index cadf9edf..0cd4da02 100644 --- a/openfeature/multi/strategies.go +++ b/openfeature/multi/strategies.go @@ -55,7 +55,7 @@ type ( StrategyFn[T FlagTypes] func(ctx context.Context, flag string, defaultValue T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] // StrategyConstructor defines the signature for the function that will be called to retrieve the closure that acts // as the custom strategy implementation. This function should return a [StrategyFn] - StrategyConstructor func(providers []*NamedProvider) StrategyFn[FlagTypes] + StrategyConstructor func(providers []NamedProvider) StrategyFn[FlagTypes] ) // Common Components @@ -136,7 +136,7 @@ func BuildDefaultResult[R FlagTypes](strategy EvaluationStrategy, defaultValue R // Evaluate is a generic method used to resolve a flag from a single [NamedProvider] without losing type information. // This method is exported for those writing their own custom [StrategyFn]. Since any is an allowed [FlagTypes] this can // be set to any type, but this should be done with care outside the specified primitive [FlagTypes] -func Evaluate[T FlagTypes](ctx context.Context, provider *NamedProvider, flag string, defaultVal T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { +func Evaluate[T FlagTypes](ctx context.Context, provider NamedProvider, flag string, defaultVal T, flatCtx of.FlattenedContext) of.GenericResolutionDetail[T] { var resolution of.GenericResolutionDetail[T] switch v := any(defaultVal).(type) { case bool: @@ -165,7 +165,7 @@ func Evaluate[T FlagTypes](ctx context.Context, provider *NamedProvider, flag st resolution.FlagMetadata = make(of.FlagMetadata, 2) } - resolution.FlagMetadata[MetadataProviderName] = provider.Name + resolution.FlagMetadata[MetadataProviderName] = provider.Name() resolution.FlagMetadata[MetadataProviderType] = provider.Metadata().Name return resolution