diff --git a/common/resource/params.go b/common/resource/params.go index 633575420df..9d667d7b346 100644 --- a/common/resource/params.go +++ b/common/resource/params.go @@ -101,5 +101,11 @@ type ( // ShardDistributorMatchingConfig is the config for shard distributor executor client in matching service ShardDistributorMatchingConfig clientcommon.Config + + // DrainObserver is an optional observer that signals when this instance is + // drained from service discovery. + // It is used by shard-distributor executor clients to + // gracefully stop processing during drains. + DrainObserver clientcommon.DrainSignalObserver } ) diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index 6998fccca99..e028f843619 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -111,6 +111,7 @@ type ( timeSource clock.TimeSource failoverNotificationVersion int64 ShardDistributorMatchingConfig clientcommon.Config + drainObserver clientcommon.DrainSignalObserver } ) @@ -140,6 +141,7 @@ func NewEngine( timeSource clock.TimeSource, shardDistributorClient executorclient.Client, ShardDistributorMatchingConfig clientcommon.Config, + drainObserver clientcommon.DrainSignalObserver, ) Engine { e := &matchingEngineImpl{ taskListRegistry: tasklist.NewTaskListRegistry(metricsClient), @@ -161,6 +163,7 @@ func NewEngine( isolationState: isolationState, timeSource: timeSource, ShardDistributorMatchingConfig: ShardDistributorMatchingConfig, + drainObserver: drainObserver, } e.setupExecutor(shardDistributorClient) @@ -215,6 +218,7 @@ func (e *matchingEngineImpl) setupExecutor(shardDistributorExecutorClient execut "grpc": fmt.Sprintf("%d", e.config.RPCConfig.GRPCPort), "hostIP": hostIP.String(), }, + DrainObserver: e.drainObserver, } executor, err := executorclient.NewExecutor[tasklist.ShardProcessor](params) if err != nil { diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go index 7649d2ddf60..c47e71a6027 100644 --- a/service/matching/handler/engine_integration_test.go +++ b/service/matching/handler/engine_integration_test.go @@ -194,6 +194,7 @@ func (s *matchingEngineSuite) newMatchingEngine( s.mockTimeSource, s.mockShardExecutorClient, defaultSDExecutorConfig(), + nil, ).(*matchingEngineImpl) } diff --git a/service/matching/handler/membership_test.go b/service/matching/handler/membership_test.go index 3be1cb422fb..644a45e1f9c 100644 --- a/service/matching/handler/membership_test.go +++ b/service/matching/handler/membership_test.go @@ -139,6 +139,7 @@ func TestGetTaskListManager_OwnerShip(t *testing.T) { mockTimeSource, mockShardDistributorExecutorClient, defaultSDExecutorConfig(), + nil, ).(*matchingEngineImpl) resolverMock.EXPECT().Lookup(gomock.Any(), gomock.Any()).Return( diff --git a/service/matching/service.go b/service/matching/service.go index 82512c6c918..782a2d5caee 100644 --- a/service/matching/service.go +++ b/service/matching/service.go @@ -45,6 +45,7 @@ type Service struct { stopC chan struct{} config *config.Config ShardDistributorMatchingConfig clientcommon.Config + drainObserver clientcommon.DrainSignalObserver } // NewService builds a new cadence-matching service @@ -84,6 +85,7 @@ func NewService( config: serviceConfig, stopC: make(chan struct{}), ShardDistributorMatchingConfig: params.ShardDistributorMatchingConfig, + drainObserver: params.DrainObserver, }, nil } @@ -111,6 +113,7 @@ func (s *Service) Start() { s.GetTimeSource(), s.GetShardDistributorExecutorClient(), s.ShardDistributorMatchingConfig, + s.drainObserver, ) s.handler = handler.NewHandler(engine, s.config, s.GetDomainCache(), s.GetMetricsClient(), s.GetLogger(), s.GetThrottledLogger())