Skip to content

Commit 08f4ade

Browse files
committed
improve event handling in multiprovider
- remove inboundEvents channel in favor of direct pipe channel - fix race condition in updateProviderStateFromEvent with lock - optimize Shutdown to only spawn goroutines for StateHandler providers - ensure outboundEvents channel is closed after all workers complete Signed-off-by: Roman Dmytrenko <[email protected]>
1 parent 6650376 commit 08f4ade

File tree

2 files changed

+82
-70
lines changed

2 files changed

+82
-70
lines changed

openfeature/multi/multiprovider.go

Lines changed: 80 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ type (
4343
strategyFunc StrategyFn[FlagTypes] // used for evaluating strategies
4444
logger *slog.Logger
4545
outboundEvents chan of.Event
46-
inboundEvents chan namedEvent
4746
workerGroup sync.WaitGroup
4847
shutdownFunc context.CancelFunc
4948
globalHooks []of.Hook
@@ -84,6 +83,12 @@ type (
8483
providers []*namedProvider
8584
customComparator Comparator
8685
}
86+
87+
// namedEventHandler is a wrapper around an [of.EventHandler] that includes the provider name.
88+
namedEventHandler struct {
89+
of.EventHandler
90+
name string
91+
}
8792
)
8893

8994
func (n *namedProvider) Name() string {
@@ -352,12 +357,7 @@ func (p *Provider) ObjectEvaluation(ctx context.Context, flag string, defaultVal
352357
func (p *Provider) Init(evalCtx of.EvaluationContext) error {
353358
var eg errgroup.Group
354359
// wrapper type used only for initialization of event listener workers
355-
type namedEventHandler struct {
356-
of.EventHandler
357-
name string
358-
}
359360
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "start initialization")
360-
p.inboundEvents = make(chan namedEvent, len(p.providers))
361361
handlers := make(chan namedEventHandler, len(p.providers))
362362
for _, provider := range p.providers {
363363
name := provider.Name()
@@ -402,62 +402,77 @@ func (p *Provider) Init(evalCtx of.EvaluationContext) error {
402402
}
403403
close(handlers)
404404
workerCtx, shutdownFunc := context.WithCancel(context.Background())
405-
for h := range handlers {
406-
go p.startListening(workerCtx, h.name, h.EventHandler, &p.workerGroup)
407-
}
408405
p.shutdownFunc = shutdownFunc
409406

410-
p.workerGroup.Add(1)
411-
go func(ctx context.Context) {
412-
workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker"))
413-
defer p.workerGroup.Done()
414-
415-
for {
416-
select {
417-
case <-ctx.Done():
418-
return
419-
case e := <-p.inboundEvents:
420-
l := workerLogger.With(
421-
slog.String(MetadataProviderName, e.providerName),
422-
slog.String(MetadataProviderType, e.ProviderName),
423-
)
424-
l.LogAttrs(ctx, slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType)))
425-
if p.updateProviderStateFromEvent(e) {
426-
p.outboundEvents <- e.Event
427-
l.LogAttrs(ctx, slog.LevelDebug, "forwarded state update event")
428-
} else {
429-
l.LogAttrs(ctx, slog.LevelDebug, "total state not updated, inbound event will not be emitted")
430-
}
431-
}
432-
}
433-
}(workerCtx)
407+
if len(handlers) > 0 {
408+
go p.forwardProviderEvents(workerCtx, handlers)
409+
} else {
410+
// we don't emit any events so we can just close the channel
411+
close(p.outboundEvents)
412+
}
434413

435414
p.setStatus(of.ReadyState)
436415
p.initialized = true
437416
return nil
438417
}
439418

440-
// startListening is intended to be called on a per-provider basis as a goroutine to listen to events from a provider
441-
// implementing [of.EventHandler].
442-
func (p *Provider) startListening(ctx context.Context, name string, h of.EventHandler, wg *sync.WaitGroup) {
443-
wg.Add(1)
444-
defer wg.Done()
445-
for {
446-
select {
447-
case e := <-h.EventChannel():
448-
if e.EventMetadata == nil {
449-
e.EventMetadata = make(map[string]any)
450-
}
451-
e.EventMetadata[MetadataProviderName] = name
452-
if p, ok := h.(of.FeatureProvider); ok {
453-
e.EventMetadata[MetadataProviderType] = p.Metadata().Name
454-
}
455-
p.inboundEvents <- namedEvent{
456-
Event: e,
457-
providerName: name,
419+
// forwardProviderEvents establishes an event forwarding pipeline that collects events from multiple provider
420+
// event handlers and forwards them to the multiprovider's outbound event channel. It spawns a goroutine for
421+
// each provider handler to listen for events, aggregates them through an internal pipe, and selectively forwards
422+
// events that result in state changes. The function blocks until workerCtx is cancelled or all provider event
423+
// channels are closed, ensuring proper cleanup by closing the outbound channel when complete.
424+
func (p *Provider) forwardProviderEvents(workerCtx context.Context, handlers chan namedEventHandler) {
425+
p.workerGroup.Add(1)
426+
defer p.workerGroup.Done()
427+
defer close(p.outboundEvents)
428+
429+
workerLogger := p.logger.With(slog.String("multiprovider-worker", "event-forwarder-worker"))
430+
pipe := make(chan namedEvent)
431+
var wg sync.WaitGroup
432+
for ch := range handlers {
433+
wg.Add(1)
434+
go func(ctx context.Context, h of.EventHandler, name string, out chan<- namedEvent) {
435+
defer wg.Done()
436+
for {
437+
select {
438+
case <-ctx.Done():
439+
return
440+
case e, ok := <-h.EventChannel():
441+
if !ok {
442+
return
443+
}
444+
if e.EventMetadata == nil {
445+
e.EventMetadata = make(map[string]any)
446+
}
447+
e.EventMetadata[MetadataProviderName] = name
448+
if p, ok := h.(of.FeatureProvider); ok {
449+
e.EventMetadata[MetadataProviderType] = p.Metadata().Name
450+
}
451+
out <- namedEvent{
452+
Event: e,
453+
providerName: name,
454+
}
455+
}
458456
}
459-
case <-ctx.Done():
460-
return
457+
}(workerCtx, ch.EventHandler, ch.name, pipe)
458+
}
459+
460+
go func() {
461+
wg.Wait()
462+
close(pipe)
463+
}()
464+
465+
for e := range pipe {
466+
l := workerLogger.With(
467+
slog.String(MetadataProviderName, e.providerName),
468+
slog.String(MetadataProviderType, e.ProviderName),
469+
)
470+
l.LogAttrs(workerCtx, slog.LevelDebug, "received event from provider", slog.String("event-type", string(e.EventType)))
471+
if p.updateProviderStateFromEvent(e) {
472+
p.outboundEvents <- e.Event
473+
l.LogAttrs(workerCtx, slog.LevelDebug, "forwarded state update event")
474+
} else {
475+
l.LogAttrs(workerCtx, slog.LevelDebug, "total state not updated, inbound event will not be emitted")
461476
}
462477
}
463478
}
@@ -483,7 +498,10 @@ func (p *Provider) updateProviderStateFromEvent(e namedEvent) bool {
483498
if e.EventType == of.ProviderConfigChange {
484499
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "ProviderConfigChange event", slog.String("event-message", e.Message))
485500
}
486-
logProviderState(p.logger, e, p.providerStatus[e.providerName])
501+
p.providerStatusLock.Lock()
502+
previousState := p.providerStatus[e.providerName]
503+
p.providerStatusLock.Unlock()
504+
logProviderState(p.logger, e, previousState)
487505
return p.updateProviderState(e.providerName, eventTypeToState[e.EventType])
488506
}
489507

@@ -530,14 +548,13 @@ func (p *Provider) Shutdown() {
530548

531549
var wg sync.WaitGroup
532550
for _, provider := range p.providers {
533-
wg.Add(1)
534-
535-
go func(p NamedProvider) {
536-
defer wg.Done()
537-
if stateHandle, ok := p.unwrap().(of.StateHandler); ok {
538-
stateHandle.Shutdown()
539-
}
540-
}(provider)
551+
if stateHandle, ok := provider.unwrap().(of.StateHandler); ok {
552+
wg.Add(1)
553+
go func(p of.StateHandler) {
554+
defer wg.Done()
555+
p.Shutdown()
556+
}(stateHandle)
557+
}
541558
}
542559

543560
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "waiting for provider shutdown completion")
@@ -547,15 +564,10 @@ func (p *Provider) Shutdown() {
547564
// Wait for workers to stop
548565
p.workerGroup.Wait()
549566
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "worker shutdown completed")
550-
close(p.inboundEvents)
551567
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "starting provider shutdown")
552-
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider shutdown completed")
553-
close(p.outboundEvents)
554568
p.setStatus(of.NotReadyState)
555-
556-
p.outboundEvents = nil
557-
p.inboundEvents = nil
558569
p.initialized = false
570+
p.logger.LogAttrs(context.Background(), slog.LevelDebug, "provider shutdown completed")
559571
}
560572

561573
// Status provides the current state of the [multi.Provider].

openfeature/multi/multiprovider_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,8 @@ func TestMultiProvider_Track(t *testing.T) {
454454
// wait for event processing
455455
require.Eventually(t, func() bool {
456456
select {
457-
case <-mp.outboundEvents:
458-
return true
457+
case e := <-mp.outboundEvents:
458+
return e.ProviderName == "error-provider" && e.EventType == of.ProviderError
459459
default:
460460
return false
461461
}

0 commit comments

Comments
 (0)