Skip to content

Commit 0769cea

Browse files
committed
fix race issues
Signed-off-by: Roman Dmytrenko <[email protected]>
1 parent 0eb65f7 commit 0769cea

File tree

2 files changed

+35
-35
lines changed

2 files changed

+35
-35
lines changed

openfeature/multi/multiprovider.go

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy,
219219
opt(config)
220220
}
221221

222-
providers := providerMap
222+
providers := make(ProviderMap, len(providerMap))
223223
collectedHooks := make([]of.Hook, 0, len(providerMap))
224224
for name, provider := range providerMap {
225225
// Validate Providers
@@ -233,6 +233,7 @@ func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy,
233233

234234
// Wrap any providers that include hooks
235235
if (len(provider.Hooks()) + len(config.providerHooks[name])) == 0 {
236+
providers[name] = provider
236237
continue
237238
}
238239

@@ -251,7 +252,7 @@ func NewProvider(providerMap ProviderMap, evaluationStrategy EvaluationStrategy,
251252
providers: providers,
252253
outboundEvents: make(chan of.Event, len(providers)),
253254
logger: config.logger,
254-
metadata: buildMetadata(providerMap),
255+
metadata: buildMetadata(providers),
255256
overallStatus: of.NotReadyState,
256257
providerStatus: make(map[string]of.State, len(providers)),
257258
globalHooks: append(config.hooks, collectedHooks...),
@@ -373,6 +374,7 @@ func (p *Provider) Init(evalCtx of.EvaluationContext) error {
373374
l.LogAttrs(context.Background(), slog.LevelDebug, "StateHandle not implemented, skipping initialization")
374375
} else if err := stateHandle.Init(evalCtx); err != nil {
375376
l.LogAttrs(context.Background(), slog.LevelError, "initialization failed", slog.Any("error", err))
377+
p.updateProviderState(name, of.ErrorState)
376378
return &ProviderError{
377379
Err: err,
378380
ProviderName: name,
@@ -390,10 +392,7 @@ func (p *Provider) Init(evalCtx of.EvaluationContext) error {
390392

391393
if err := eg.Wait(); err != nil {
392394
var pErr *ProviderError
393-
if errors.As(err, &pErr) {
394-
// Update provider status to error, no event needs to be emitted yet
395-
p.updateProviderState(pErr.ProviderName, of.ErrorState)
396-
} else {
395+
if !errors.As(err, &pErr) {
397396
pErr = &ProviderError{
398397
Err: err,
399398
ProviderName: "unknown",
@@ -411,23 +410,30 @@ func (p *Provider) Init(evalCtx of.EvaluationContext) error {
411410
p.shutdownFunc = shutdownFunc
412411

413412
p.workerGroup.Add(1)
414-
go func() {
413+
go func(ctx context.Context) {
415414
workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker"))
416415
defer p.workerGroup.Done()
417-
for e := range p.inboundEvents {
418-
l := workerLogger.With(
419-
slog.String(MetadataProviderName, e.providerName),
420-
slog.String(MetadataProviderType, e.ProviderName),
421-
)
422-
l.LogAttrs(context.Background(), slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType)))
423-
if p.updateProviderStateFromEvent(e) {
424-
p.outboundEvents <- e.Event
425-
l.LogAttrs(context.Background(), slog.LevelDebug, "forwarded state update event")
426-
} else {
427-
l.LogAttrs(context.Background(), slog.LevelDebug, "total state not updated, inbound event will not be emitted")
416+
417+
for {
418+
select {
419+
case <-ctx.Done():
420+
close(p.outboundEvents)
421+
return
422+
case e := <-p.inboundEvents:
423+
l := workerLogger.With(
424+
slog.String(MetadataProviderName, e.providerName),
425+
slog.String(MetadataProviderType, e.ProviderName),
426+
)
427+
l.LogAttrs(context.Background(), slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType)))
428+
if p.updateProviderStateFromEvent(e) {
429+
p.outboundEvents <- e.Event
430+
l.LogAttrs(context.Background(), slog.LevelDebug, "forwarded state update event")
431+
} else {
432+
l.LogAttrs(context.Background(), slog.LevelDebug, "total state not updated, inbound event will not be emitted")
433+
}
428434
}
429435
}
430-
}()
436+
}(workerCtx)
431437

432438
p.setStatus(of.ReadyState)
433439
p.initialized = true
@@ -519,13 +525,7 @@ func (p *Provider) Shutdown() {
519525
}
520526
// Stop all event listener workers, shutdown events should not affect overall state
521527
p.shutdownFunc()
522-
// Stop forwarding worker
523-
close(p.inboundEvents)
524-
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "triggered worker shutdown")
525-
// Wait for workers to stop
526-
p.workerGroup.Wait()
527-
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "worker shutdown completed")
528-
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "starting provider shutdown")
528+
529529
var wg sync.WaitGroup
530530
for _, provider := range p.providers {
531531
wg.Add(1)
@@ -540,9 +540,15 @@ func (p *Provider) Shutdown() {
540540

541541
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "waiting for provider shutdown completion")
542542
wg.Wait()
543+
// Stop forwarding worker
544+
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "triggered worker shutdown")
545+
// Wait for workers to stop
546+
p.workerGroup.Wait()
547+
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "worker shutdown completed")
548+
close(p.inboundEvents)
549+
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "starting provider shutdown")
543550
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider shutdown completed")
544551
p.setStatus(of.NotReadyState)
545-
close(p.outboundEvents)
546552
p.outboundEvents = nil
547553
p.inboundEvents = nil
548554
p.initialized = false

openfeature/multi/multiprovider_test.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,14 +475,8 @@ func TestMultiProvider_Track(t *testing.T) {
475475
require.NoError(t, err)
476476

477477
// Simulate error state for one provider
478-
errorProvider.eventChannel <- of.Event{
479-
ProviderName: "error-provider",
480-
EventType: of.ProviderError,
481-
ProviderEventDetails: of.ProviderEventDetails{
482-
Message: "error",
483-
EventMetadata: make(map[string]any),
484-
},
485-
}
478+
errorProvider.EmitEvent(of.ProviderError, "error")
479+
486480
// wait for event processing
487481
<-mp.outboundEvents
488482

0 commit comments

Comments
 (0)