diff --git a/service/sharddistributor/client/clientcommon/drain_observer.go b/service/sharddistributor/client/clientcommon/drain_observer.go new file mode 100644 index 00000000000..f908dc2f021 --- /dev/null +++ b/service/sharddistributor/client/clientcommon/drain_observer.go @@ -0,0 +1,20 @@ +package clientcommon + +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination drain_observer_mock.go . DrainSignalObserver + +// DrainSignalObserver observes infrastructure drain signals. +// Drain is reversible: if the instance reappears in discovery, +// Undrain() fires, allowing the consumer to resume operations. +// +// Implementations use close-to-broadcast semantics: the returned channel is +// closed when the event occurs, so all goroutines selecting on it wake up. +// After each close, a fresh channel is created for the next cycle. +type DrainSignalObserver interface { + // Drain returns a channel closed when the instance is + // removed from service discovery. + Drain() <-chan struct{} + + // Undrain returns a channel closed when the instance is + // added back to service discovery after a drain. + Undrain() <-chan struct{} +} diff --git a/service/sharddistributor/client/clientcommon/drain_observer_mock.go b/service/sharddistributor/client/clientcommon/drain_observer_mock.go new file mode 100644 index 00000000000..3b55e273580 --- /dev/null +++ b/service/sharddistributor/client/clientcommon/drain_observer_mock.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: drain_observer.go +// +// Generated by this command: +// +// mockgen -package clientcommon -source drain_observer.go -destination drain_observer_mock.go . DrainSignalObserver +// + +// Package clientcommon is a generated GoMock package. +package clientcommon + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockDrainSignalObserver is a mock of DrainSignalObserver interface. +type MockDrainSignalObserver struct { + ctrl *gomock.Controller + recorder *MockDrainSignalObserverMockRecorder + isgomock struct{} +} + +// MockDrainSignalObserverMockRecorder is the mock recorder for MockDrainSignalObserver. +type MockDrainSignalObserverMockRecorder struct { + mock *MockDrainSignalObserver +} + +// NewMockDrainSignalObserver creates a new mock instance. +func NewMockDrainSignalObserver(ctrl *gomock.Controller) *MockDrainSignalObserver { + mock := &MockDrainSignalObserver{ctrl: ctrl} + mock.recorder = &MockDrainSignalObserverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDrainSignalObserver) EXPECT() *MockDrainSignalObserverMockRecorder { + return m.recorder +} + +// Drain mocks base method. +func (m *MockDrainSignalObserver) Drain() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Drain") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// Drain indicates an expected call of Drain. +func (mr *MockDrainSignalObserverMockRecorder) Drain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drain", reflect.TypeOf((*MockDrainSignalObserver)(nil).Drain)) +} + +// Undrain mocks base method. +func (m *MockDrainSignalObserver) Undrain() <-chan struct{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Undrain") + ret0, _ := ret[0].(<-chan struct{}) + return ret0 +} + +// Undrain indicates an expected call of Undrain. 
+func (mr *MockDrainSignalObserverMockRecorder) Undrain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Undrain", reflect.TypeOf((*MockDrainSignalObserver)(nil).Undrain)) +} diff --git a/service/sharddistributor/client/executorclient/client.go b/service/sharddistributor/client/executorclient/client.go index 35eb533615c..3a3a1792154 100644 --- a/service/sharddistributor/client/executorclient/client.go +++ b/service/sharddistributor/client/executorclient/client.go @@ -73,7 +73,8 @@ type Params[SP ShardProcessor] struct { ShardProcessorFactory ShardProcessorFactory[SP] Config clientcommon.Config TimeSource clock.TimeSource - Metadata ExecutorMetadata `optional:"true"` + Metadata ExecutorMetadata `optional:"true"` + DrainObserver clientcommon.DrainSignalObserver `optional:"true"` } // NewExecutorWithNamespace creates an executor for a specific namespace @@ -136,6 +137,7 @@ func newExecutorWithConfig[SP ShardProcessor](params Params[SP], namespaceConfig metadata: syncExecutorMetadata{ data: params.Metadata, }, + drainObserver: params.DrainObserver, } executor.setMigrationMode(namespaceConfig.GetMigrationMode()) diff --git a/service/sharddistributor/client/executorclient/clientimpl.go b/service/sharddistributor/client/executorclient/clientimpl.go index 9e4640e7a53..4588365a0f6 100644 --- a/service/sharddistributor/client/executorclient/clientimpl.go +++ b/service/sharddistributor/client/executorclient/clientimpl.go @@ -16,6 +16,7 @@ import ( "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/types" + "github.com/uber/cadence/service/sharddistributor/client/clientcommon" "github.com/uber/cadence/service/sharddistributor/client/executorclient/metricsconstants" "github.com/uber/cadence/service/sharddistributor/client/executorclient/syncgeneric" ) @@ -104,6 +105,7 @@ type executorImpl[SP ShardProcessor] struct { metrics tally.Scope migrationMode atomic.Int32 metadata syncExecutorMetadata + drainObserver clientcommon.DrainSignalObserver } func (e *executorImpl[SP]) setMigrationMode(mode types.MigrationMode) { @@ -202,6 +204,11 @@ func (e *executorImpl[SP]) heartbeatloop(ctx context.Context) { heartBeatTimer := e.timeSource.NewTimer(backoff.JitDuration(e.heartBeatInterval, heartbeatJitterCoeff)) defer heartBeatTimer.Stop() + var drainCh <-chan struct{} + if e.drainObserver != nil { + drainCh = e.drainObserver.Drain() + } + for { select { case <-ctx.Done(): @@ -214,6 +221,18 @@ func (e *executorImpl[SP]) heartbeatloop(ctx context.Context) { e.stopShardProcessors() e.sendDrainingHeartbeat() return + case <-drainCh: + e.logger.Info("drain signal received, stopping shard processors") + e.stopShardProcessors() + e.sendDrainingHeartbeat() + + if !e.waitForUndrain(ctx) { + return + } + + e.logger.Info("undrain signal received, resuming heartbeat") + drainCh = e.drainObserver.Drain() + heartBeatTimer.Reset(backoff.JitDuration(e.heartBeatInterval, heartbeatJitterCoeff)) case <-heartBeatTimer.Chan(): heartBeatTimer.Reset(backoff.JitDuration(e.heartBeatInterval, heartbeatJitterCoeff)) err := e.heartbeatAndUpdateAssignment(ctx) @@ -229,6 +248,25 @@ func (e *executorImpl[SP]) heartbeatloop(ctx context.Context) { } } +// waitForUndrain blocks until the undrain signal fires or the executor is stopped. +// Returns true if undrained (caller should resume), false if stopped. 
+func (e *executorImpl[SP]) waitForUndrain(ctx context.Context) bool { + if e.drainObserver == nil { + return false + } + + undrainCh := e.drainObserver.Undrain() + + select { + case <-ctx.Done(): + return false + case <-e.stopC: + return false + case <-undrainCh: + return true + } +} + func (e *executorImpl[SP]) heartbeatAndUpdateAssignment(ctx context.Context) error { if !e.assignmentMutex.TryLock() { e.logger.Error("still doing assignment, skipping heartbeat") diff --git a/service/sharddistributor/leader/namespace/manager.go b/service/sharddistributor/leader/namespace/manager.go index 7fb57e437c7..81815e02e1d 100644 --- a/service/sharddistributor/leader/namespace/manager.go +++ b/service/sharddistributor/leader/namespace/manager.go @@ -9,6 +9,7 @@ import ( "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/service/sharddistributor/client/clientcommon" "github.com/uber/cadence/service/sharddistributor/config" "github.com/uber/cadence/service/sharddistributor/leader/election" ) @@ -19,21 +20,27 @@ var Module = fx.Module( fx.Invoke(NewManager), ) +// stateFn represents a state in the election state machine. +// Each state is a function that blocks until a transition occurs +// and returns the next state function, or nil to stop. +type stateFn func(ctx context.Context) stateFn + type Manager struct { cfg config.ShardDistribution logger log.Logger electionFactory election.Factory + drainObserver clientcommon.DrainSignalObserver namespaces map[string]*namespaceHandler ctx context.Context cancel context.CancelFunc } type namespaceHandler struct { - logger log.Logger - elector election.Elector - cancel context.CancelFunc - namespaceCfg config.Namespace - cleanupWg sync.WaitGroup + logger log.Logger + electionFactory election.Factory + namespaceCfg config.Namespace + drainObserver clientcommon.DrainSignalObserver + cleanupWg sync.WaitGroup } type ManagerParams struct { @@ -43,6 +50,7 @@ type ManagerParams struct { Logger log.Logger ElectionFactory election.Factory Lifecycle fx.Lifecycle + DrainObserver clientcommon.DrainSignalObserver `optional:"true"` } // NewManager creates a new namespace manager @@ -51,6 +59,7 @@ func NewManager(p ManagerParams) *Manager { cfg: p.Cfg, logger: p.Logger.WithTags(tag.ComponentNamespaceManager), electionFactory: p.ElectionFactory, + drainObserver: p.DrainObserver, namespaces: make(map[string]*namespaceHandler), } @@ -73,7 +82,9 @@ func (m *Manager) Start(ctx context.Context) error { return nil } -// Stop gracefully stops all namespace handlers +// Stop gracefully stops all namespace handlers. +// Cancels the manager context which cascades to all handler contexts, +// then waits for all election goroutines to finish. func (m *Manager) Stop(ctx context.Context) error { if m.cancel == nil { return fmt.Errorf("manager was not running") @@ -82,69 +93,110 @@ func (m *Manager) Stop(ctx context.Context) error { m.cancel() for ns, handler := range m.namespaces { - m.logger.Info("Stopping namespace handler", tag.ShardNamespace(ns)) - if handler.cancel != nil { - handler.cancel() - } + m.logger.Info("Waiting for namespace handler to stop", tag.ShardNamespace(ns)) + handler.cleanupWg.Wait() } return nil } -// handleNamespace sets up leadership election for a namespace +// handleNamespace sets up a namespace handler and starts its election goroutine. 
func (m *Manager) handleNamespace(namespaceCfg config.Namespace) error { if _, exists := m.namespaces[namespaceCfg.Name]; exists { return fmt.Errorf("namespace %s already running", namespaceCfg.Name) } - m.logger.Info("Setting up namespace handler", tag.ShardNamespace(namespaceCfg.Name)) - - ctx, cancel := context.WithCancel(m.ctx) - - // Create elector for this namespace - elector, err := m.electionFactory.CreateElector(ctx, namespaceCfg) - if err != nil { - cancel() - return err - } - handler := &namespaceHandler{ - logger: m.logger.WithTags(tag.ShardNamespace(namespaceCfg.Name)), - elector: elector, - } - // cancel cancels the context and ensures that electionRunner is stopped. - handler.cancel = func() { - cancel() - handler.cleanupWg.Wait() + logger: m.logger.WithTags(tag.ShardNamespace(namespaceCfg.Name)), + electionFactory: m.electionFactory, + namespaceCfg: namespaceCfg, + drainObserver: m.drainObserver, } m.namespaces[namespaceCfg.Name] = handler handler.cleanupWg.Add(1) - // Start leadership election - go handler.runElection(ctx) + + go handler.runElection(m.ctx) return nil } -// runElection manages the leadership election for a namespace -func (handler *namespaceHandler) runElection(ctx context.Context) { - defer handler.cleanupWg.Done() +// runElection drives the election state machine for a namespace. +// It starts in the campaigning state and follows state transitions +// until a state returns nil (stop). +func (h *namespaceHandler) runElection(ctx context.Context) { + defer h.cleanupWg.Done() - handler.logger.Info("Starting election for namespace") + for state := h.campaigning; state != nil; { + state = state(ctx) + } +} - leaderCh := handler.elector.Run(ctx) +// campaigning creates an elector and participates in leader election. +// Transitions: h.idle on drain, h.campaigning on recoverable error, nil on stop. +func (h *namespaceHandler) campaigning(ctx context.Context) stateFn { + h.logger.Info("Entering campaigning state") + + var drainCh <-chan struct{} + if h.drainObserver != nil { + drainCh = h.drainObserver.Drain() + } + + // Check if already drained before creating an elector. + select { + case <-drainCh: + h.logger.Info("Drain signal detected before election start") + return h.idle + default: + } + + electorCtx, cancel := context.WithCancel(ctx) + defer cancel() + + elector, err := h.electionFactory.CreateElector(electorCtx, h.namespaceCfg) + if err != nil { + h.logger.Error("Failed to create elector", tag.Error(err)) + return nil + } + + leaderCh := elector.Run(electorCtx) for { select { case <-ctx.Done(): - handler.logger.Info("Context cancelled, stopping election") - return - case isLeader := <-leaderCh: + return nil + case <-drainCh: + h.logger.Info("Drain signal received, resigning from election") + return h.idle + case isLeader, ok := <-leaderCh: + if !ok { + h.logger.Error("Election channel closed unexpectedly") + return h.campaigning + } if isLeader { - handler.logger.Info("Became leader for namespace") + h.logger.Info("Became leader for namespace") } else { - handler.logger.Info("Lost leadership for namespace") + h.logger.Info("Lost leadership for namespace") } } } } + +// idle waits for an undrain signal to resume campaigning. +// Transitions: h.campaigning on undrain, nil on stop. 
+func (h *namespaceHandler) idle(ctx context.Context) stateFn { + h.logger.Info("Entering idle state (drained)") + + var undrainCh <-chan struct{} + if h.drainObserver != nil { + undrainCh = h.drainObserver.Undrain() + } + + select { + case <-ctx.Done(): + return nil + case <-undrainCh: + h.logger.Info("Undrain signal received, resuming election") + return h.campaigning + } +} diff --git a/service/sharddistributor/leader/namespace/manager_test.go b/service/sharddistributor/leader/namespace/manager_test.go index df398f45295..542147ea375 100644 --- a/service/sharddistributor/leader/namespace/manager_test.go +++ b/service/sharddistributor/leader/namespace/manager_test.go @@ -3,6 +3,7 @@ package namespace import ( "context" "errors" + "sync" "testing" "time" @@ -16,8 +17,61 @@ import ( "github.com/uber/cadence/service/sharddistributor/leader/election" ) +// mockElectorRun returns a DoAndReturn function that simulates an elector: +// it returns a leaderCh and closes it when the context is cancelled. +func mockElectorRun(leaderCh chan bool) func(ctx context.Context) <-chan bool { + return func(ctx context.Context) <-chan bool { + go func() { + <-ctx.Done() + close(leaderCh) + }() + return (<-chan bool)(leaderCh) + } +} + +// closeDrainObserver is a test helper that simulates the close-and-recreate +// semantics of the real DrainSignalObserver. It wraps a mock and manages +// the channel lifecycle. +type closeDrainObserver struct { + mu sync.Mutex + drainCh chan struct{} + undrainCh chan struct{} +} + +func newCloseDrainObserver() *closeDrainObserver { + return &closeDrainObserver{ + drainCh: make(chan struct{}), + undrainCh: make(chan struct{}), + } +} + +func (o *closeDrainObserver) Drain() <-chan struct{} { + o.mu.Lock() + defer o.mu.Unlock() + return o.drainCh +} + +func (o *closeDrainObserver) Undrain() <-chan struct{} { + o.mu.Lock() + defer o.mu.Unlock() + return o.undrainCh +} + +func (o *closeDrainObserver) SignalDrain() { + o.mu.Lock() + defer o.mu.Unlock() + close(o.drainCh) + o.undrainCh = make(chan struct{}) +} + +func (o *closeDrainObserver) SignalUndrain() { + o.mu.Lock() + defer o.mu.Unlock() + close(o.undrainCh) + o.drainCh = make(chan struct{}) +} + func TestNewManager(t *testing.T) { - // Setup logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) @@ -28,7 +82,6 @@ func TestNewManager(t *testing.T) { }, } - // Test manager := NewManager(ManagerParams{ Cfg: cfg, Logger: logger, @@ -36,23 +89,20 @@ func TestNewManager(t *testing.T) { Lifecycle: fxtest.NewLifecycle(t), }) - // Assert assert.NotNil(t, manager) assert.Equal(t, cfg, manager.cfg) assert.Equal(t, 0, len(manager.namespaces)) } func TestStartManager(t *testing.T) { - // Setup logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) elector := election.NewMockElector(ctrl) - electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) - leaderCh := make(chan bool) - elector.EXPECT().Run(gomock.Any()).Return((<-chan bool)(leaderCh)) + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) cfg := config.ShardDistribution{ Namespaces: []config.Namespace{ @@ -67,22 +117,21 @@ func TestStartManager(t *testing.T) { namespaces: make(map[string]*namespaceHandler), } - // Test err := manager.Start(context.Background()) + time.Sleep(10 * time.Millisecond) - // Try to give goroutine time to 
start. - time.Sleep(time.Millisecond) - - // Assert assert.NoError(t, err) assert.NotNil(t, manager.ctx) assert.NotNil(t, manager.cancel) assert.Equal(t, 1, len(manager.namespaces)) assert.Contains(t, manager.namespaces, "test-namespace") + + // Cleanup + manager.cancel() + manager.namespaces["test-namespace"].cleanupWg.Wait() } func TestStartManagerWithElectorError(t *testing.T) { - // Setup logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) @@ -103,26 +152,26 @@ func TestStartManagerWithElectorError(t *testing.T) { namespaces: make(map[string]*namespaceHandler), } - // Test err := manager.Start(context.Background()) + assert.NoError(t, err) - // Assert - assert.Error(t, err) - assert.Equal(t, expectedErr, err) - assert.Equal(t, 0, len(manager.namespaces)) + // The goroutine exits on elector creation error + handler := manager.namespaces["test-namespace"] + handler.cleanupWg.Wait() + + // Cleanup + manager.cancel() } func TestStopManager(t *testing.T) { - // Setup logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) elector := election.NewMockElector(ctrl) - electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) - leaderCh := make(chan bool) - elector.EXPECT().Run(gomock.Any()).Return((<-chan bool)(leaderCh)) + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) cfg := config.ShardDistribution{ Namespaces: []config.Namespace{ @@ -137,25 +186,17 @@ func TestStopManager(t *testing.T) { namespaces: make(map[string]*namespaceHandler), } - // Start the manager first _ = manager.Start(context.Background()) + time.Sleep(10 * time.Millisecond) - // Try to give goroutine time to start. 
- time.Sleep(time.Millisecond) - - // Test err := manager.Stop(context.Background()) - - // Assert assert.NoError(t, err) } func TestHandleNamespaceAlreadyExists(t *testing.T) { - // Setup logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) - mockElector := election.NewMockElector(ctrl) manager := &Manager{ cfg: config.ShardDistribution{}, @@ -164,32 +205,62 @@ func TestHandleNamespaceAlreadyExists(t *testing.T) { namespaces: make(map[string]*namespaceHandler), } - // Set context manager.ctx, manager.cancel = context.WithCancel(context.Background()) + defer manager.cancel() - // Add existing namespace handler - manager.namespaces["test-namespace"] = &namespaceHandler{ - elector: mockElector, - } + manager.namespaces["test-namespace"] = &namespaceHandler{} - // Test err := manager.handleNamespace(config.Namespace{Name: "test-namespace"}) - - // Assert assert.ErrorContains(t, err, "namespace test-namespace already running") } -func TestRunElection(t *testing.T) { - // Setup +func TestRunElection_LeadershipEvents(t *testing.T) { logger := testlogger.New(t) ctrl := gomock.NewController(t) electionFactory := election.NewMockFactory(ctrl) elector := election.NewMockElector(ctrl) + leaderCh := make(chan bool) electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) + + cfg := config.ShardDistribution{ + Namespaces: []config.Namespace{ + {Name: "test-namespace"}, + }, + } + + manager := &Manager{ + cfg: cfg, + logger: logger, + electionFactory: electionFactory, + namespaces: make(map[string]*namespaceHandler), + } + + err := manager.Start(context.Background()) + require.NoError(t, err) + + leaderCh <- true + time.Sleep(10 * time.Millisecond) + + leaderCh <- false + time.Sleep(10 * time.Millisecond) + + err = manager.Stop(context.Background()) + assert.NoError(t, err) +} + +func TestDrainSignal_TriggersResign(t *testing.T) { + logger := testlogger.New(t) + ctrl := gomock.NewController(t) + electionFactory := election.NewMockFactory(ctrl) + elector := election.NewMockElector(ctrl) leaderCh := make(chan bool) - elector.EXPECT().Run(gomock.Any()).Return((<-chan bool)(leaderCh)) + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) + + observer := newCloseDrainObserver() cfg := config.ShardDistribution{ Namespaces: []config.Namespace{ @@ -201,25 +272,144 @@ func TestRunElection(t *testing.T) { cfg: cfg, logger: logger, electionFactory: electionFactory, + drainObserver: observer, namespaces: make(map[string]*namespaceHandler), } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Start the test goroutine - err := manager.Start(ctx) + err := manager.Start(context.Background()) require.NoError(t, err) - // Test becoming leader + // Wait for elector to be running leaderCh <- true - time.Sleep(10 * time.Millisecond) // Give some time for goroutine to process + time.Sleep(10 * time.Millisecond) - // Test losing leadership - leaderCh <- false - time.Sleep(10 * time.Millisecond) // Give some time for goroutine to process + // Close drain channel — all handlers see it + observer.SignalDrain() + time.Sleep(50 * time.Millisecond) - // Cancel context to end the goroutine - manager.cancel() - time.Sleep(10 * time.Millisecond) // Give some time for goroutine to exit + // Handler should be in idle state + err = 
manager.Stop(context.Background()) + assert.NoError(t, err) +} + +func TestDrainSignal_NilDrainObserver(t *testing.T) { + logger := testlogger.New(t) + ctrl := gomock.NewController(t) + electionFactory := election.NewMockFactory(ctrl) + elector := election.NewMockElector(ctrl) + + leaderCh := make(chan bool) + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) + + cfg := config.ShardDistribution{ + Namespaces: []config.Namespace{ + {Name: "test-namespace"}, + }, + } + + manager := &Manager{ + cfg: cfg, + logger: logger, + electionFactory: electionFactory, + namespaces: make(map[string]*namespaceHandler), + } + + err := manager.Start(context.Background()) + require.NoError(t, err) + + assert.Nil(t, manager.drainObserver) + + err = manager.Stop(context.Background()) + assert.NoError(t, err) +} + +func TestDrainSignal_ManagerStopsBeforeDrain(t *testing.T) { + logger := testlogger.New(t) + ctrl := gomock.NewController(t) + electionFactory := election.NewMockFactory(ctrl) + elector := election.NewMockElector(ctrl) + + leaderCh := make(chan bool) + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector, nil) + elector.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh)) + + observer := newCloseDrainObserver() + + cfg := config.ShardDistribution{ + Namespaces: []config.Namespace{ + {Name: "test-namespace"}, + }, + } + + manager := &Manager{ + cfg: cfg, + logger: logger, + electionFactory: electionFactory, + drainObserver: observer, + namespaces: make(map[string]*namespaceHandler), + } + + err := manager.Start(context.Background()) + require.NoError(t, err) + + // Stop before drain fires + err = manager.Stop(context.Background()) + assert.NoError(t, err) +} + +func TestDrainThenUndrain_ResumesElection(t *testing.T) { + logger := testlogger.New(t) + ctrl := gomock.NewController(t) + electionFactory := election.NewMockFactory(ctrl) + + elector1 := election.NewMockElector(ctrl) + leaderCh1 := make(chan bool) + elector2 := election.NewMockElector(ctrl) + leaderCh2 := make(chan bool) + + gomock.InOrder( + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector1, nil), + electionFactory.EXPECT().CreateElector(gomock.Any(), gomock.Any()).Return(elector2, nil), + ) + elector1.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh1)) + elector2.EXPECT().Run(gomock.Any()).DoAndReturn(mockElectorRun(leaderCh2)) + + observer := newCloseDrainObserver() + + cfg := config.ShardDistribution{ + Namespaces: []config.Namespace{ + {Name: "test-namespace"}, + }, + } + + manager := &Manager{ + cfg: cfg, + logger: logger, + electionFactory: electionFactory, + drainObserver: observer, + namespaces: make(map[string]*namespaceHandler), + } + + err := manager.Start(context.Background()) + require.NoError(t, err) + + // Phase 1: elector1 running, become leader + leaderCh1 <- true + time.Sleep(10 * time.Millisecond) + + // Drain — elector1 resigns + observer.SignalDrain() + time.Sleep(50 * time.Millisecond) + + // Undrain — elector2 created, campaign again + observer.SignalUndrain() + time.Sleep(50 * time.Millisecond) + + // Phase 1 again: elector2 running, become leader + leaderCh2 <- true + time.Sleep(10 * time.Millisecond) + + err = manager.Stop(context.Background()) + assert.NoError(t, err) }
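
Illustrative sketch (not part of this diff): the change defines the DrainSignalObserver contract and consumes it in the executor heartbeat loop and the namespace election state machine, but the only concrete observer here is the test helper closeDrainObserver. Below is a minimal sketch of how a production observer could implement the close-to-broadcast contract documented in drain_observer.go and be supplied through fx. The package name drainexample, the discoveryDrainObserver type, its newDiscoveryDrainObserver constructor, and the OnInstanceRemoved/OnInstanceAdded hooks are all assumptions for illustration; only clientcommon.DrainSignalObserver and the fx APIs come from the change itself.

// Package drainexample is an illustrative sketch of a concrete
// DrainSignalObserver following the close-to-broadcast contract in
// drain_observer.go: close the current channel to wake every listener,
// then create a fresh channel for the next cycle.
package drainexample

import (
	"sync"

	"go.uber.org/fx"

	"github.com/uber/cadence/service/sharddistributor/client/clientcommon"
)

// discoveryDrainObserver flips between drained and undrained as the local
// instance leaves and rejoins service discovery. The hook names
// OnInstanceRemoved/OnInstanceAdded are hypothetical; wire them to whatever
// membership callback the deployment actually has.
type discoveryDrainObserver struct {
	mu        sync.Mutex
	drained   bool
	drainCh   chan struct{}
	undrainCh chan struct{}
}

var _ clientcommon.DrainSignalObserver = (*discoveryDrainObserver)(nil)

func newDiscoveryDrainObserver() *discoveryDrainObserver {
	return &discoveryDrainObserver{
		drainCh:   make(chan struct{}),
		undrainCh: make(chan struct{}),
	}
}

func (o *discoveryDrainObserver) Drain() <-chan struct{} {
	o.mu.Lock()
	defer o.mu.Unlock()
	return o.drainCh
}

func (o *discoveryDrainObserver) Undrain() <-chan struct{} {
	o.mu.Lock()
	defer o.mu.Unlock()
	return o.undrainCh
}

// OnInstanceRemoved broadcasts the drain event once and arms the next
// undrain cycle with a fresh channel.
func (o *discoveryDrainObserver) OnInstanceRemoved() {
	o.mu.Lock()
	defer o.mu.Unlock()
	if o.drained {
		return // already drained; closing the channel twice would panic
	}
	o.drained = true
	close(o.drainCh)
	o.undrainCh = make(chan struct{})
}

// OnInstanceAdded broadcasts the undrain event once and arms the next
// drain cycle with a fresh channel.
func (o *discoveryDrainObserver) OnInstanceAdded() {
	o.mu.Lock()
	defer o.mu.Unlock()
	if !o.drained {
		return
	}
	o.drained = false
	close(o.undrainCh)
	o.drainCh = make(chan struct{})
}

// Module provides the observer under the interface type that both the
// executor client Params and the namespace ManagerParams accept through
// their `optional:"true"` fields.
var Module = fx.Module(
	"drain-signal-observer",
	fx.Provide(func() clientcommon.DrainSignalObserver {
		return newDiscoveryDrainObserver()
	}),
)

Because the field is optional in both Params structs, deployments that register no such provider keep the current behavior: the observer stays nil, the drain channels stay nil, and the new select cases in heartbeatloop and campaigning never fire.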