diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index c0b3c1695bc..4d86251436a 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -27,7 +27,6 @@ import ( "context" "errors" "fmt" - "math" "sync" "time" @@ -89,6 +88,8 @@ type ( } matchingEngineImpl struct { + taskListCreationLock sync.Mutex + taskListsRegistry tasklist.ManagerRegistry shutdownCompletion *sync.WaitGroup shutdown chan struct{} taskManager persistence.TaskManager @@ -99,8 +100,6 @@ type ( logger log.Logger metricsClient metrics.Client metricsScope tally.Scope - taskListsLock sync.RWMutex // locks mutation of taskLists - taskLists map[tasklist.Identifier]tasklist.Manager // Convert to LRU cache executor executorclient.Executor[tasklist.ShardProcessor] taskListsFactory *tasklist.ShardProcessorFactory config *config.Config @@ -143,13 +142,13 @@ func NewEngine( ShardDistributorMatchingConfig clientcommon.Config, ) Engine { e := &matchingEngineImpl{ + taskListsRegistry: tasklist.NewManagerRegistry(metricsClient), shutdown: make(chan struct{}), shutdownCompletion: &sync.WaitGroup{}, taskManager: taskManager, clusterMetadata: clusterMetadata, historyService: historyService, tokenSerializer: common.NewJSONTaskTokenSerializer(), - taskLists: make(map[tasklist.Identifier]tasklist.Manager), logger: logger.WithTags(tag.ComponentMatchingEngine), metricsClient: metricsClient, metricsScope: metricsScope, @@ -180,7 +179,7 @@ func (e *matchingEngineImpl) Stop() { close(e.shutdown) e.executor.Stop() // Executes Stop() on each task list outside of lock - for _, l := range e.getTaskLists(math.MaxInt32) { + for _, l := range e.taskListsRegistry.AllManagers() { l.Stop() } e.unregisterDomainFailoverCallback() @@ -191,10 +190,9 @@ func (e *matchingEngineImpl) setupExecutor(shardDistributorExecutorClient execut cfg, reportTTL := e.getValidatedShardDistributorConfig() taskListFactory := &tasklist.ShardProcessorFactory{ - TaskListsLock: &e.taskListsLock, - TaskLists: e.taskLists, - ReportTTL: reportTTL, - TimeSource: e.timeSource, + TaskListsRegistry: e.taskListsRegistry, + ReportTTL: reportTTL, + TimeSource: e.timeSource, } e.taskListsFactory = taskListFactory @@ -241,26 +239,15 @@ func (e *matchingEngineImpl) getValidatedShardDistributorConfig() (clientcommon. return cfg, reportTTL } -func (e *matchingEngineImpl) getTaskLists(maxCount int) []tasklist.Manager { - e.taskListsLock.RLock() - defer e.taskListsLock.RUnlock() - lists := make([]tasklist.Manager, 0, len(e.taskLists)) - count := 0 - for _, tlMgr := range e.taskLists { - lists = append(lists, tlMgr) - count++ - if count >= maxCount { - break - } - } - return lists -} - func (e *matchingEngineImpl) String() string { // Executes taskList.String() on each task list outside of lock buf := new(bytes.Buffer) - for _, l := range e.getTaskLists(1000) { - fmt.Fprintf(buf, "\n%s", l.String()) + + for i, tl := range e.taskListsRegistry.AllManagers() { + if i >= 1000 { + break + } + fmt.Fprintf(buf, "\n%s", tl.String()) } return buf.String() } @@ -274,12 +261,9 @@ func (e *matchingEngineImpl) getOrCreateTaskListManager(ctx context.Context, tas if sp != nil { // The first check is an optimization so almost all requests will have a task list manager // and return avoiding the write lock - e.taskListsLock.RLock() - if result, ok := e.taskLists[*taskList]; ok { - e.taskListsLock.RUnlock() + if result, ok := e.taskListsRegistry.ManagerByTaskListIdentifier(*taskList); ok { return result, nil } - e.taskListsLock.RUnlock() } err := e.errIfShardOwnershipLost(ctx, taskList) if err != nil { @@ -287,9 +271,9 @@ func (e *matchingEngineImpl) getOrCreateTaskListManager(ctx context.Context, tas } // If it gets here, write lock and check again in case a task list is created between the two locks - e.taskListsLock.Lock() - if result, ok := e.taskLists[*taskList]; ok { - e.taskListsLock.Unlock() + e.taskListCreationLock.Lock() + if result, ok := e.taskListsRegistry.ManagerByTaskListIdentifier(*taskList); ok { + e.taskListCreationLock.Unlock() return result, nil } @@ -309,7 +293,7 @@ func (e *matchingEngineImpl) getOrCreateTaskListManager(ctx context.Context, tas ClusterMetadata: e.clusterMetadata, IsolationState: e.isolationState, MatchingClient: e.matchingClient, - Registry: e, // Engine implements ManagerRegistry + Registry: e.taskListsRegistry, TaskList: taskList, TaskListKind: taskListKind, Cfg: e.config, @@ -319,17 +303,14 @@ func (e *matchingEngineImpl) getOrCreateTaskListManager(ctx context.Context, tas } mgr, err := tasklist.NewManager(params) if err != nil { - e.taskListsLock.Unlock() + e.taskListCreationLock.Unlock() logger.Info("Task list manager state changed", tag.LifeCycleStartFailed, tag.Error(err)) return nil, err } - e.taskLists[*taskList] = mgr - e.metricsClient.Scope(metrics.MatchingTaskListMgrScope).UpdateGauge( - metrics.TaskListManagersGauge, - float64(len(e.taskLists)), - ) - e.taskListsLock.Unlock() + e.taskListsRegistry.Register(*taskList, mgr) + e.taskListCreationLock.Unlock() + err = mgr.Start(context.Background()) if err != nil { logger.Info("Task list manager state changed", tag.LifeCycleStartFailed, tag.Error(err)) @@ -360,44 +341,6 @@ func (e *matchingEngineImpl) getOrCreateTaskListManager(ctx context.Context, tas return mgr, nil } -func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string, taskListKind *types.TaskListKind) *types.GetTaskListsByDomainResponse { - decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse) - activityTaskListMap := make(map[string]*types.DescribeTaskListResponse) - for tl, tlm := range e.taskLists { - if tl.GetDomainID() == domainID && (taskListKind == nil || tlm.GetTaskListKind() == *taskListKind) { - if types.TaskListType(tl.GetType()) == types.TaskListTypeDecision { - decisionTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) - } else { - activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) - } - } - } - return &types.GetTaskListsByDomainResponse{ - DecisionTaskListMap: decisionTaskListMap, - ActivityTaskListMap: activityTaskListMap, - } -} - -// UnregisterManager implements tasklist.ManagerRegistry. -// It removes a task list manager from the engine's tracking map when the manager stops. -func (e *matchingEngineImpl) UnregisterManager(mgr tasklist.Manager) { - id := mgr.TaskListID() - e.taskListsLock.Lock() - defer e.taskListsLock.Unlock() - - // we need to make sure= we still hold the given `mgr` or we - // already created a new one. - currentTlMgr, ok := e.taskLists[*id] - if ok && currentTlMgr == mgr { - delete(e.taskLists, *id) - } - - e.metricsClient.Scope(metrics.MatchingTaskListMgrScope).UpdateGauge( - metrics.TaskListManagersGauge, - float64(len(e.taskLists)), - ) -} - // AddDecisionTask either delivers task directly to waiting poller or save it into task list persistence. func (e *matchingEngineImpl) AddDecisionTask( hCtx *handlerContext, @@ -1151,6 +1094,26 @@ func (e *matchingEngineImpl) listTaskListPartitions( return partitionHostInfo, nil } +func (e *matchingEngineImpl) getTaskListsByDomainAndKind(domainID string, taskListKind *types.TaskListKind) *types.GetTaskListsByDomainResponse { + decisionTaskListMap := make(map[string]*types.DescribeTaskListResponse) + activityTaskListMap := make(map[string]*types.DescribeTaskListResponse) + + for _, tlm := range e.taskListsRegistry.ManagersByDomainID(domainID) { + if taskListKind == nil || tlm.GetTaskListKind() == *taskListKind { + tl := tlm.TaskListID() + if types.TaskListType(tl.GetType()) == types.TaskListTypeDecision { + decisionTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) + } else { + activityTaskListMap[tl.GetRoot()] = tlm.DescribeTaskList(false) + } + } + } + return &types.GetTaskListsByDomainResponse{ + DecisionTaskListMap: decisionTaskListMap, + ActivityTaskListMap: activityTaskListMap, + } +} + func (e *matchingEngineImpl) GetTaskListsByDomain( hCtx *handlerContext, request *types.GetTaskListsByDomainRequest, @@ -1165,9 +1128,7 @@ func (e *matchingEngineImpl) GetTaskListsByDomain( tlKind = nil } - e.taskListsLock.RLock() - defer e.taskListsLock.RUnlock() - return e.getTaskListByDomainLocked(domainID, tlKind), nil + return e.getTaskListsByDomainAndKind(domainID, tlKind), nil } func (e *matchingEngineImpl) UpdateTaskListPartitionConfig( @@ -1276,17 +1237,10 @@ func (e *matchingEngineImpl) getAllPartitions( } func (e *matchingEngineImpl) unloadTaskList(tlMgr tasklist.Manager) { - id := tlMgr.TaskListID() - e.taskListsLock.Lock() - currentTlMgr, ok := e.taskLists[*id] - if !ok || tlMgr != currentTlMgr { - e.taskListsLock.Unlock() - return + unregistered := e.taskListsRegistry.Unregister(tlMgr) + if unregistered { + tlMgr.Stop() } - delete(e.taskLists, *id) - e.taskListsLock.Unlock() - // added a new taskList - tlMgr.Stop() } // Populate the decision task response based on context and scheduled/started events. @@ -1600,9 +1554,7 @@ func (e *matchingEngineImpl) domainChangeCallback(nextDomains []*cache.DomainCac taskListNormal := types.TaskListKindNormal - e.taskListsLock.RLock() - resp := e.getTaskListByDomainLocked(domain.GetInfo().ID, &taskListNormal) - e.taskListsLock.RUnlock() + resp := e.getTaskListsByDomainAndKind(domain.GetInfo().ID, &taskListNormal) for taskListName := range resp.DecisionTaskListMap { e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeDecision, taskListNormal) @@ -1614,9 +1566,7 @@ func (e *matchingEngineImpl) domainChangeCallback(nextDomains []*cache.DomainCac taskListSticky := types.TaskListKindSticky - e.taskListsLock.RLock() - resp = e.getTaskListByDomainLocked(domain.GetInfo().ID, &taskListSticky) - e.taskListsLock.RUnlock() + resp = e.getTaskListsByDomainAndKind(domain.GetInfo().ID, &taskListSticky) for taskListName := range resp.DecisionTaskListMap { e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeDecision, taskListSticky) diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go index 38b6cb385b7..84dd9d97116 100644 --- a/service/matching/handler/engine_integration_test.go +++ b/service/matching/handler/engine_integration_test.go @@ -236,8 +236,8 @@ func (s *matchingEngineSuite) TestOnlyUnloadMatchingInstance() { ClusterMetadata: s.matchingEngine.clusterMetadata, IsolationState: s.matchingEngine.isolationState, MatchingClient: s.matchingEngine.matchingClient, - Registry: s.matchingEngine, // Engine implements ManagerRegistry - TaskList: taskListID, // same taskListID as above + Registry: s.matchingEngine.taskListsRegistry, + TaskList: taskListID, // same taskListID as above TaskListKind: tlKind, Cfg: s.matchingEngine.config, TimeSource: s.matchingEngine.timeSource, diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go index 65ee593ece9..c1a631f079a 100644 --- a/service/matching/handler/engine_test.go +++ b/service/matching/handler/engine_test.go @@ -49,6 +49,20 @@ import ( "github.com/uber/cadence/service/sharddistributor/client/executorclient" ) +func mustNewIdentifier(t *testing.T, domainID, taskListName string, taskListType int) *tasklist.Identifier { + t.Helper() + id, err := tasklist.NewIdentifier(domainID, taskListName, taskListType) + require.NoError(t, err) + return id +} + +// newMockManagerWithTaskListID returns a MockManager with TaskListID() stubbed to return id (AnyTimes). +func newMockManagerWithTaskListID(ctrl *gomock.Controller, id *tasklist.Identifier) *tasklist.MockManager { + mgr := tasklist.NewMockManager(ctrl) + mgr.EXPECT().TaskListID().Return(id).AnyTimes() + return mgr +} + func TestGetTaskListsByDomain(t *testing.T) { testCases := []struct { name string @@ -171,42 +185,38 @@ func TestGetTaskListsByDomain(t *testing.T) { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) mockDomainCache := cache.NewMockDomainCache(mockCtrl) - decisionTasklistID, err := tasklist.NewIdentifier("test-domain-id", "decision0", 0) - require.NoError(t, err) - activityTasklistID, err := tasklist.NewIdentifier("test-domain-id", "activity0", 1) - require.NoError(t, err) - otherDomainTasklistID, err := tasklist.NewIdentifier("other-domain-id", "other0", 0) - require.NoError(t, err) - mockDecisionTaskListManager := tasklist.NewMockManager(mockCtrl) - mockActivityTaskListManager := tasklist.NewMockManager(mockCtrl) - mockOtherDomainTaskListManager := tasklist.NewMockManager(mockCtrl) + decisionTasklistID := mustNewIdentifier(t, "test-domain-id", "decision0", 0) + activityTasklistID := mustNewIdentifier(t, "test-domain-id", "activity0", 1) + otherDomainTasklistID := mustNewIdentifier(t, "other-domain-id", "other0", 0) + stickyTasklistID := mustNewIdentifier(t, "test-domain-id", "sticky0", 0) + mockDecisionTaskListManager := newMockManagerWithTaskListID(mockCtrl, decisionTasklistID) + mockActivityTaskListManager := newMockManagerWithTaskListID(mockCtrl, activityTasklistID) + mockOtherDomainTaskListManager := newMockManagerWithTaskListID(mockCtrl, otherDomainTasklistID) + mockStickyManager := newMockManagerWithTaskListID(mockCtrl, stickyTasklistID) mockTaskListManagers := map[tasklist.Identifier]*tasklist.MockManager{ *decisionTasklistID: mockDecisionTaskListManager, *activityTasklistID: mockActivityTaskListManager, *otherDomainTasklistID: mockOtherDomainTaskListManager, } - stickyTasklistID, err := tasklist.NewIdentifier("test-domain-id", "sticky0", 0) - require.NoError(t, err) - mockStickyManager := tasklist.NewMockManager(mockCtrl) mockStickyManagers := map[tasklist.Identifier]*tasklist.MockManager{ *stickyTasklistID: mockStickyManager, } tc.mockSetup(mockDomainCache, mockTaskListManagers, mockStickyManagers) + taskListRegistry := tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()) engine := &matchingEngineImpl{ - domainCache: mockDomainCache, - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *decisionTasklistID: mockDecisionTaskListManager, - *activityTasklistID: mockActivityTaskListManager, - *otherDomainTasklistID: mockOtherDomainTaskListManager, - *stickyTasklistID: mockStickyManager, - }, + domainCache: mockDomainCache, + taskListsRegistry: taskListRegistry, config: &config.Config{ EnableReturnAllTaskListKinds: func(opts ...dynamicproperties.FilterOption) bool { return tc.returnAllKinds }, }, } + taskListRegistry.Register(*decisionTasklistID, mockDecisionTaskListManager) + taskListRegistry.Register(*activityTasklistID, mockActivityTaskListManager) + taskListRegistry.Register(*otherDomainTasklistID, mockOtherDomainTaskListManager) + taskListRegistry.Register(*stickyTasklistID, mockStickyManager) resp, err := engine.GetTaskListsByDomain(nil, &types.GetTaskListsByDomainRequest{Domain: "test-domain"}) if tc.wantErr { @@ -372,19 +382,18 @@ func TestCancelOutstandingPoll(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + tasklistID := mustNewIdentifier(t, "test-domain-id", "test-tasklist", 0) + mockManager := newMockManagerWithTaskListID(mockCtrl, tasklistID) executor := executorclient.NewMockExecutor[tasklist.ShardProcessor](mockCtrl) tc.mockSetup(mockCtrl, mockManager, executor) - tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) - require.NoError(t, err) + taskListRegistry := tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tasklistID: mockManager, - }, - executor: executor, + taskListsRegistry: taskListRegistry, + executor: executor, } + taskListRegistry.Register(*tasklistID, mockManager) hCtx := &handlerContext{Context: context.Background()} - err = engine.CancelOutstandingPoll(hCtx, tc.req) + err := engine.CancelOutstandingPoll(hCtx, tc.req) if tc.wantErr { require.Error(t, err) } else { @@ -559,18 +568,17 @@ func TestQueryWorkflow(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + tasklistID := mustNewIdentifier(t, "test-domain-id", "test-tasklist", 0) + mockManager := newMockManagerWithTaskListID(mockCtrl, tasklistID) executor := executorclient.NewMockExecutor[tasklist.ShardProcessor](mockCtrl) - tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) - require.NoError(t, err) + taskListRegistry := tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tasklistID: mockManager, - }, + taskListsRegistry: taskListRegistry, timeSource: clock.NewRealTimeSource(), lockableQueryTaskMap: lockableQueryTaskMap{queryTaskMap: make(map[string]chan *queryResult)}, executor: executor, } + taskListRegistry.Register(*tasklistID, mockManager) tc.mockSetup(mockManager, &engine.lockableQueryTaskMap, mockCtrl, executor) resp, err := engine.QueryWorkflow(tc.hCtx, tc.req) if tc.wantErr { @@ -736,6 +744,7 @@ func TestIsShuttingDown(t *testing.T) { shutdownCompletion: &wg, shutdown: make(chan struct{}), executor: mockExecutor, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), } e.Start() assert.False(t, e.isShuttingDown()) @@ -750,13 +759,13 @@ func TestGetTasklistsNotOwned(t *testing.T) { resolver.EXPECT().WhoAmI().Return(membership.NewDetailedHostInfo("self", "host123", nil), nil) - tl1, _ := tasklist.NewIdentifier("", "tl1", 0) - tl2, _ := tasklist.NewIdentifier("", "tl2", 0) - tl3, _ := tasklist.NewIdentifier("", "tl3", 0) + tl1 := mustNewIdentifier(t, "", "tl1", 0) + tl2 := mustNewIdentifier(t, "", "tl2", 0) + tl3 := mustNewIdentifier(t, "", "tl3", 0) - tl1m := tasklist.NewMockManager(ctrl) - tl2m := tasklist.NewMockManager(ctrl) - tl3m := tasklist.NewMockManager(ctrl) + tl1m := newMockManagerWithTaskListID(ctrl, tl1) + tl2m := newMockManagerWithTaskListID(ctrl, tl2) + tl3m := newMockManagerWithTaskListID(ctrl, tl3) resolver.EXPECT().Lookup(service.Matching, tl1.GetName()).Return(membership.NewDetailedHostInfo("", "host123", nil), nil) resolver.EXPECT().Lookup(service.Matching, tl2.GetName()).Return(membership.NewDetailedHostInfo("", "host456", nil), nil) @@ -765,17 +774,15 @@ func TestGetTasklistsNotOwned(t *testing.T) { e := matchingEngineImpl{ shutdown: make(chan struct{}), membershipResolver: resolver, - taskListsLock: sync.RWMutex{}, - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tl1: tl1m, - *tl2: tl2m, - *tl3: tl3m, - }, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, logger: log.NewNoop(), } + e.taskListsRegistry.Register(*tl1, tl1m) + e.taskListsRegistry.Register(*tl2, tl2m) + e.taskListsRegistry.Register(*tl3, tl3m) tls, err := e.getNonOwnedTasklistsLocked() assert.NoError(t, err) @@ -790,13 +797,13 @@ func TestShutDownTasklistsNotOwned(t *testing.T) { resolver.EXPECT().WhoAmI().Return(membership.NewDetailedHostInfo("self", "host123", nil), nil) - tl1, _ := tasklist.NewIdentifier("", "tl1", 0) - tl2, _ := tasklist.NewIdentifier("", "tl2", 0) - tl3, _ := tasklist.NewIdentifier("", "tl3", 0) + tl1 := mustNewIdentifier(t, "", "tl1", 0) + tl2 := mustNewIdentifier(t, "", "tl2", 0) + tl3 := mustNewIdentifier(t, "", "tl3", 0) - tl1m := tasklist.NewMockManager(ctrl) - tl2m := tasklist.NewMockManager(ctrl) - tl3m := tasklist.NewMockManager(ctrl) + tl1m := newMockManagerWithTaskListID(ctrl, tl1) + tl2m := newMockManagerWithTaskListID(ctrl, tl2) + tl3m := newMockManagerWithTaskListID(ctrl, tl3) resolver.EXPECT().Lookup(service.Matching, tl1.GetName()).Return(membership.NewDetailedHostInfo("", "host123", nil), nil) resolver.EXPECT().Lookup(service.Matching, tl2.GetName()).Return(membership.NewDetailedHostInfo("", "host456", nil), nil) @@ -805,24 +812,21 @@ func TestShutDownTasklistsNotOwned(t *testing.T) { e := matchingEngineImpl{ shutdown: make(chan struct{}), membershipResolver: resolver, - taskListsLock: sync.RWMutex{}, - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tl1: tl1m, - *tl2: tl2m, - *tl3: tl3m, - }, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, metricsClient: metrics.NewNoopMetricsClient(), logger: log.NewNoop(), } + e.taskListsRegistry.Register(*tl1, tl1m) + e.taskListsRegistry.Register(*tl2, tl2m) + e.taskListsRegistry.Register(*tl3, tl3m) wg := sync.WaitGroup{} wg.Add(1) - tl2m.EXPECT().TaskListID().Return(tl2).AnyTimes() tl2m.EXPECT().String().AnyTimes() tl2m.EXPECT().Stop().Do(func() { @@ -1029,23 +1033,22 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { mockCtrl := gomock.NewController(t) mockDomainCache := cache.NewMockDomainCache(mockCtrl) mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) - mockManager := tasklist.NewMockManager(mockCtrl) + tasklistID := mustNewIdentifier(t, "test-domain-id", "test-tasklist", 1) + mockManager := newMockManagerWithTaskListID(mockCtrl, tasklistID) mockExecutor := executorclient.NewMockExecutor[tasklist.ShardProcessor](mockCtrl) tc.mockSetup(mockManager, mockCtrl, mockExecutor) - tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) - require.NoError(t, err) + taskListRegistry := tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tasklistID: mockManager, - }, - timeSource: clock.NewRealTimeSource(), - domainCache: mockDomainCache, + taskListsRegistry: taskListRegistry, + timeSource: clock.NewRealTimeSource(), + domainCache: mockDomainCache, config: &config.Config{ EnableAdaptiveScaler: dynamicproperties.GetBoolPropertyFilteredByTaskListInfo(tc.enableAdaptiveScaler), }, executor: mockExecutor, } - _, err = engine.UpdateTaskListPartitionConfig(tc.hCtx, tc.req) + taskListRegistry.Register(*tasklistID, mockManager) + _, err := engine.UpdateTaskListPartitionConfig(tc.hCtx, tc.req) if tc.expectError { assert.ErrorContains(t, err, tc.expectedError) } else { @@ -1210,22 +1213,20 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + tasklistID := mustNewIdentifier(t, "test-domain-id", "test-tasklist", 1) + tasklistID2 := mustNewIdentifier(t, "test-domain-id", "/__cadence_sys/test-tasklist/1", 1) + mockManager := newMockManagerWithTaskListID(mockCtrl, tasklistID) mockExecutor := executorclient.NewMockExecutor[tasklist.ShardProcessor](mockCtrl) tc.mockSetup(mockManager, mockCtrl, mockExecutor) - tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) - require.NoError(t, err) - tasklistID2, err := tasklist.NewIdentifier("test-domain-id", "/__cadence_sys/test-tasklist/1", 1) - require.NoError(t, err) + taskListRegistry := tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tasklistID: mockManager, - *tasklistID2: mockManager, - }, - timeSource: clock.NewRealTimeSource(), - executor: mockExecutor, + taskListsRegistry: taskListRegistry, + timeSource: clock.NewRealTimeSource(), + executor: mockExecutor, } - _, err = engine.RefreshTaskListPartitionConfig(tc.hCtx, tc.req) + taskListRegistry.Register(*tasklistID, mockManager) + taskListRegistry.Register(*tasklistID2, mockManager) + _, err := engine.RefreshTaskListPartitionConfig(tc.hCtx, tc.req) if tc.expectError { assert.ErrorContains(t, err, tc.expectedError) } else { @@ -1241,67 +1242,82 @@ func Test_domainChangeCallback(t *testing.T) { clusters := []string{"cluster0", "cluster1"} - mockTaskListManagerGlobal1 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl) - mockStickyTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl) - mockStickyTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerLocal1 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerActiveActive1 := tasklist.NewMockManager(mockCtrl) + tlGlobalDecision1 := mustNewIdentifier(t, "global-domain-1-id", "global-domain-1", persistence.TaskListTypeDecision) + tlGlobalActivity1 := mustNewIdentifier(t, "global-domain-1-id", "global-domain-1", persistence.TaskListTypeActivity) + tlGlobalDecision2 := mustNewIdentifier(t, "global-domain-2-id", "global-domain-2", persistence.TaskListTypeDecision) + tlGlobalActivity2 := mustNewIdentifier(t, "global-domain-2-id", "global-domain-2", persistence.TaskListTypeActivity) + tlGlobalSticky2 := mustNewIdentifier(t, "global-domain-2-id", "sticky-global-domain-2", persistence.TaskListTypeDecision) + tlGlobalActivity3 := mustNewIdentifier(t, "global-domain-3-id", "global-domain-3", persistence.TaskListTypeActivity) + tlGlobalDecision3 := mustNewIdentifier(t, "global-domain-3-id", "global-domain-3", persistence.TaskListTypeDecision) + tlGlobalSticky3 := mustNewIdentifier(t, "global-domain-3-id", "sticky-global-domain-3", persistence.TaskListTypeDecision) + tlLocalDecision1 := mustNewIdentifier(t, "local-domain-1-id", "local-domain-1", persistence.TaskListTypeDecision) + tlLocalActivity1 := mustNewIdentifier(t, "local-domain-1-id", "local-domain-1", persistence.TaskListTypeActivity) + tlActiveActiveDecision1 := mustNewIdentifier(t, "active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeDecision) + tlActiveActiveActivity1 := mustNewIdentifier(t, "active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeActivity) + + newNormalManager := func(id *tasklist.Identifier) *tasklist.MockManager { + mgr := newMockManagerWithTaskListID(mockCtrl, id) + mgr.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).AnyTimes() + mgr.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).AnyTimes() + return mgr + } + newStickyManager := func(id *tasklist.Identifier) *tasklist.MockManager { + mgr := newMockManagerWithTaskListID(mockCtrl, id) + mgr.EXPECT().GetTaskListKind().Return(types.TaskListKindSticky).AnyTimes() + mgr.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).AnyTimes() + return mgr + } + + mockGlobalDecision1 := newNormalManager(tlGlobalDecision1) + mockGlobalActivity1 := newNormalManager(tlGlobalActivity1) + mockGlobalDecision2 := newNormalManager(tlGlobalDecision2) + mockGlobalActivity2 := newNormalManager(tlGlobalActivity2) + mockGlobalSticky2 := newStickyManager(tlGlobalSticky2) + mockGlobalDecision3 := newNormalManager(tlGlobalDecision3) + mockGlobalActivity3 := newNormalManager(tlGlobalActivity3) + mockGlobalSticky3 := newStickyManager(tlGlobalSticky3) + mockLocalDecision1 := newNormalManager(tlLocalDecision1) + mockLocalActivity1 := newNormalManager(tlLocalActivity1) + mockActiveActiveDecision1 := newNormalManager(tlActiveActiveDecision1) + mockActiveActiveActivity1 := newNormalManager(tlActiveActiveActivity1) mockExecutor := executorclient.NewMockExecutor[tasklist.ShardProcessor](mockCtrl) mockExecutor.EXPECT().GetShardProcess(gomock.Any(), gomock.Any()).Return(tasklist.NewMockShardProcessor(mockCtrl), nil).AnyTimes() - tlIdentifierDecisionGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeDecision) - tlIdentifierActivityGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeActivity) - tlIdentifierDecisionGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeDecision) - tlIdentifierActivityGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeActivity) - tlIdentifierStickyGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "sticky-global-domain-2", persistence.TaskListTypeDecision) - tlIdentifierActivityGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeActivity) - tlIdentifierDecisionGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeDecision) - tlIdentifierStickyGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "sticky-global-domain-3", persistence.TaskListTypeDecision) - tlIdentifierDecisionLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeDecision) - tlIdentifierActivityLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeActivity) - tlIdentifierDecisionActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeDecision) - tlIdentifierActivityActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeActivity) - engine := &matchingEngineImpl{ domainCache: mockDomainCache, failoverNotificationVersion: 1, config: defaultTestConfig(), logger: log.NewNoop(), - taskLists: map[tasklist.Identifier]tasklist.Manager{ - *tlIdentifierDecisionGlobal1: mockTaskListManagerGlobal1, - *tlIdentifierActivityGlobal1: mockTaskListManagerGlobal1, - *tlIdentifierDecisionGlobal2: mockTaskListManagerGlobal2, - *tlIdentifierActivityGlobal2: mockTaskListManagerGlobal2, - *tlIdentifierStickyGlobal2: mockStickyTaskListManagerGlobal2, - *tlIdentifierDecisionGlobal3: mockTaskListManagerGlobal3, - *tlIdentifierActivityGlobal3: mockTaskListManagerGlobal3, - *tlIdentifierStickyGlobal3: mockStickyTaskListManagerGlobal3, - *tlIdentifierDecisionLocal1: mockTaskListManagerLocal1, - *tlIdentifierActivityLocal1: mockTaskListManagerLocal1, - *tlIdentifierDecisionActiveActive1: mockTaskListManagerActiveActive1, - *tlIdentifierActivityActiveActive1: mockTaskListManagerActiveActive1, - }, - executor: mockExecutor, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), + executor: mockExecutor, } - - mockTaskListManagerGlobal1.EXPECT().ReleaseBlockedPollers().Times(0) - mockTaskListManagerGlobal2.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(4) - mockStickyTaskListManagerGlobal2.EXPECT().GetTaskListKind().Return(types.TaskListKindSticky).Times(2) - mockTaskListManagerGlobal2.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2) - mockStickyTaskListManagerGlobal2.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(1) - mockTaskListManagerGlobal2.EXPECT().ReleaseBlockedPollers().Times(2) - mockStickyTaskListManagerGlobal2.EXPECT().ReleaseBlockedPollers().Times(1) - mockTaskListManagerGlobal3.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(4) - mockStickyTaskListManagerGlobal3.EXPECT().GetTaskListKind().Return(types.TaskListKindSticky).Times(2) - mockTaskListManagerGlobal3.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2) - mockStickyTaskListManagerGlobal3.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(1) - mockTaskListManagerGlobal3.EXPECT().ReleaseBlockedPollers().Return(errors.New("some-error")).Times(2) - mockStickyTaskListManagerGlobal3.EXPECT().ReleaseBlockedPollers().Return(errors.New("some-error")).Times(1) - mockTaskListManagerLocal1.EXPECT().ReleaseBlockedPollers().Times(0) - mockTaskListManagerActiveActive1.EXPECT().ReleaseBlockedPollers().Times(0) + engine.taskListsRegistry.Register(*tlGlobalDecision1, mockGlobalDecision1) + engine.taskListsRegistry.Register(*tlGlobalActivity1, mockGlobalActivity1) + engine.taskListsRegistry.Register(*tlGlobalDecision2, mockGlobalDecision2) + engine.taskListsRegistry.Register(*tlGlobalActivity2, mockGlobalActivity2) + engine.taskListsRegistry.Register(*tlGlobalSticky2, mockGlobalSticky2) + engine.taskListsRegistry.Register(*tlGlobalDecision3, mockGlobalDecision3) + engine.taskListsRegistry.Register(*tlGlobalActivity3, mockGlobalActivity3) + engine.taskListsRegistry.Register(*tlGlobalSticky3, mockGlobalSticky3) + engine.taskListsRegistry.Register(*tlLocalDecision1, mockLocalDecision1) + engine.taskListsRegistry.Register(*tlLocalActivity1, mockLocalActivity1) + engine.taskListsRegistry.Register(*tlActiveActiveDecision1, mockActiveActiveDecision1) + engine.taskListsRegistry.Register(*tlActiveActiveActivity1, mockActiveActiveActivity1) + + // Eligible for failover handling is defined by isDomainEligibleToDisconnectPollers. + mockGlobalDecision1.EXPECT().ReleaseBlockedPollers().Times(0) // global-domain-1 has failover version 0 (<= current 1), so not eligible. + mockGlobalActivity1.EXPECT().ReleaseBlockedPollers().Times(0) // global-domain-1 has failover version 0 (<= current 1), so not eligible. + mockGlobalDecision2.EXPECT().ReleaseBlockedPollers().Times(1) // global-domain-2 is global, non-active-active, and version 4 > 1. + mockGlobalActivity2.EXPECT().ReleaseBlockedPollers().Times(1) // global-domain-2 is global, non-active-active, and version 4 > 1. + mockGlobalSticky2.EXPECT().ReleaseBlockedPollers().Times(1) // sticky task list under eligible global-domain-2. + mockGlobalDecision3.EXPECT().ReleaseBlockedPollers().Times(1) // global-domain-3 is eligible. + mockGlobalActivity3.EXPECT().ReleaseBlockedPollers().Times(1) // global-domain-3 is eligible. + mockGlobalSticky3.EXPECT().ReleaseBlockedPollers().Times(1) // sticky task list under eligible global-domain-3. + mockLocalDecision1.EXPECT().ReleaseBlockedPollers().Times(0) // local domains are not eligible. + mockLocalActivity1.EXPECT().ReleaseBlockedPollers().Times(0) // local domains are not eligible. + mockActiveActiveDecision1.EXPECT().ReleaseBlockedPollers().Times(0) // active-active domains are not eligible. + mockActiveActiveActivity1.EXPECT().ReleaseBlockedPollers().Times(0) // active-active domains are not eligible. domains := []*cache.DomainCacheEntry{ cache.NewDomainCacheEntryForTest( @@ -1374,7 +1390,7 @@ func Test_domainChangeCallback(t *testing.T) { engine.domainChangeCallback(domains) - assert.Equal(t, int64(5), engine.failoverNotificationVersion) + assert.Equal(t, int64(5), engine.failoverNotificationVersion, "5 is the highest failover notification version in the fixtures (global-domain-3)") } func Test_registerDomainFailoverCallback(t *testing.T) { @@ -1402,7 +1418,7 @@ func Test_registerDomainFailoverCallback(t *testing.T) { failoverNotificationVersion: 0, config: defaultTestConfig(), logger: log.NewNoop(), - taskLists: map[tasklist.Identifier]tasklist.Manager{}, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), } engine.registerDomainFailoverCallback() diff --git a/service/matching/handler/membership.go b/service/matching/handler/membership.go index dfd21de2eac..dad59a5b987 100644 --- a/service/matching/handler/membership.go +++ b/service/matching/handler/membership.go @@ -127,22 +127,21 @@ func (e *matchingEngineImpl) getNonOwnedTasklistsLocked() ([]tasklist.Manager, e var toShutDown []tasklist.Manager - e.taskListsLock.RLock() - defer e.taskListsLock.RUnlock() + taskLists := e.taskListsRegistry.AllManagers() self, err := e.membershipResolver.WhoAmI() if err != nil { return nil, fmt.Errorf("failed to lookup self im membership: %w", err) } - for tl, manager := range e.taskLists { - taskListOwner, err := e.membershipResolver.Lookup(service.Matching, tl.GetName()) + for _, tl := range taskLists { + taskListOwner, err := e.membershipResolver.Lookup(service.Matching, tl.TaskListID().GetName()) if err != nil { return nil, fmt.Errorf("failed to lookup task list owner: %w", err) } if taskListOwner.Identity() != self.Identity() { - toShutDown = append(toShutDown, manager) + toShutDown = append(toShutDown, tl) } } diff --git a/service/matching/handler/membership_test.go b/service/matching/handler/membership_test.go index 5ba20c39c25..461e0af51ed 100644 --- a/service/matching/handler/membership_test.go +++ b/service/matching/handler/membership_test.go @@ -197,13 +197,12 @@ func TestSubscriptionAndShutdown(t *testing.T) { engine := matchingEngineImpl{ shutdownCompletion: &shutdownWG, membershipResolver: mockResolver, - config: &config.Config{ - EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, - }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), - domainCache: mockDomainCache, - executor: mockExecutor, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), + config: &config.Config{EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }}, + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, + executor: mockExecutor, } mockResolver.EXPECT().WhoAmI().Return(membership.NewDetailedHostInfo("host2", "host2", nil), nil).AnyTimes() @@ -233,13 +232,12 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { engine := matchingEngineImpl{ shutdownCompletion: &shutdownWG, membershipResolver: mockResolver, - config: &config.Config{ - EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, - }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), - domainCache: mockDomainCache, - executor: mockExecutor, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), + config: &config.Config{EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }}, + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, + executor: mockExecutor, } // this should trigger the error case on a membership event @@ -288,12 +286,11 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { engine := matchingEngineImpl{ shutdownCompletion: &shutdownWG, membershipResolver: mockResolver, - config: &config.Config{ - EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, - }, - shutdown: make(chan struct{}), - logger: logger, - domainCache: mockDomainCache, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), + config: &config.Config{EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }}, + shutdown: make(chan struct{}), + logger: logger, + domainCache: mockDomainCache, } mockResolver.EXPECT().Subscribe(service.Matching, "matching-engine", gomock.Any()). @@ -337,13 +334,12 @@ func TestGetTasklistManagerShutdownScenario(t *testing.T) { engine := matchingEngineImpl{ shutdownCompletion: &shutdownWG, membershipResolver: mockResolver, - config: &config.Config{ - EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, - }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), - domainCache: mockDomainCache, - executor: mockExecutor, + taskListsRegistry: tasklist.NewManagerRegistry(metrics.NewNoopMetricsClient()), + config: &config.Config{EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }}, + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, + executor: mockExecutor, } // set this engine to be shutting down to trigger the tasklistGetTasklistByID guard diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go index d7ae192b33b..7c573805a15 100644 --- a/service/matching/tasklist/interfaces.go +++ b/service/matching/tasklist/interfaces.go @@ -41,9 +41,23 @@ type ( // ManagerRegistry is implemented by components that track/own task list managers. // Managers notify their registry when they stop so they can be cleaned up. ManagerRegistry interface { - // UnregisterManager is called by a Manager when it stops, allowing the registry - // to clean up resources and remove the manager from its tracking structures. - UnregisterManager(mgr Manager) + // Register registers a manager for a given identifier. + // we can override the manager for the same identifier if it is already registered + // this case should be handled by the caller + Register(id Identifier, mgr Manager) + + // Unregister unregisters a manager for a given identifier. + // it returns true if the manager was unregistered, false if it was not found + Unregister(mgr Manager) bool + + // AllManagers returns a list of all managers. + AllManagers() []Manager + + ManagersByDomainID(domainID string) []Manager + ManagersByTaskListName(name string) []Manager + // ManagerByTaskListIdentifier returns a manager for a given identifier. + // it returns the manager and true if it was found, false if it was not found + ManagerByTaskListIdentifier(id Identifier) (Manager, bool) } Manager interface { diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go index 5c3f3844933..e2b257d1b8d 100644 --- a/service/matching/tasklist/interfaces_mock.go +++ b/service/matching/tasklist/interfaces_mock.go @@ -44,16 +44,87 @@ func (m *MockManagerRegistry) EXPECT() *MockManagerRegistryMockRecorder { return m.recorder } -// UnregisterManager mocks base method. -func (m *MockManagerRegistry) UnregisterManager(mgr Manager) { +// AllManagers mocks base method. +func (m *MockManagerRegistry) AllManagers() []Manager { m.ctrl.T.Helper() - m.ctrl.Call(m, "UnregisterManager", mgr) + ret := m.ctrl.Call(m, "AllManagers") + ret0, _ := ret[0].([]Manager) + return ret0 +} + +// AllManagers indicates an expected call of AllManagers. +func (mr *MockManagerRegistryMockRecorder) AllManagers() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllManagers", reflect.TypeOf((*MockManagerRegistry)(nil).AllManagers)) +} + +// ManagerByTaskListIdentifier mocks base method. +func (m *MockManagerRegistry) ManagerByTaskListIdentifier(id Identifier) (Manager, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ManagerByTaskListIdentifier", id) + ret0, _ := ret[0].(Manager) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// ManagerByTaskListIdentifier indicates an expected call of ManagerByTaskListIdentifier. +func (mr *MockManagerRegistryMockRecorder) ManagerByTaskListIdentifier(id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ManagerByTaskListIdentifier", reflect.TypeOf((*MockManagerRegistry)(nil).ManagerByTaskListIdentifier), id) +} + +// ManagersByDomainID mocks base method. +func (m *MockManagerRegistry) ManagersByDomainID(domainID string) []Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ManagersByDomainID", domainID) + ret0, _ := ret[0].([]Manager) + return ret0 +} + +// ManagersByDomainID indicates an expected call of ManagersByDomainID. +func (mr *MockManagerRegistryMockRecorder) ManagersByDomainID(domainID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ManagersByDomainID", reflect.TypeOf((*MockManagerRegistry)(nil).ManagersByDomainID), domainID) +} + +// ManagersByTaskListName mocks base method. +func (m *MockManagerRegistry) ManagersByTaskListName(name string) []Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ManagersByTaskListName", name) + ret0, _ := ret[0].([]Manager) + return ret0 +} + +// ManagersByTaskListName indicates an expected call of ManagersByTaskListName. +func (mr *MockManagerRegistryMockRecorder) ManagersByTaskListName(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ManagersByTaskListName", reflect.TypeOf((*MockManagerRegistry)(nil).ManagersByTaskListName), name) +} + +// Register mocks base method. +func (m *MockManagerRegistry) Register(id Identifier, mgr Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Register", id, mgr) +} + +// Register indicates an expected call of Register. +func (mr *MockManagerRegistryMockRecorder) Register(id, mgr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockManagerRegistry)(nil).Register), id, mgr) +} + +// Unregister mocks base method. +func (m *MockManagerRegistry) Unregister(mgr Manager) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unregister", mgr) + ret0, _ := ret[0].(bool) + return ret0 } -// UnregisterManager indicates an expected call of UnregisterManager. -func (mr *MockManagerRegistryMockRecorder) UnregisterManager(mgr any) *gomock.Call { +// Unregister indicates an expected call of Unregister. +func (mr *MockManagerRegistryMockRecorder) Unregister(mgr any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterManager", reflect.TypeOf((*MockManagerRegistry)(nil).UnregisterManager), mgr) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unregister", reflect.TypeOf((*MockManagerRegistry)(nil).Unregister), mgr) } // MockManager is a mock of Manager interface. diff --git a/service/matching/tasklist/shard_processor.go b/service/matching/tasklist/shard_processor.go index c2d98ba7220..669c7d67b39 100644 --- a/service/matching/tasklist/shard_processor.go +++ b/service/matching/tasklist/shard_processor.go @@ -13,23 +13,21 @@ import ( ) type ShardProcessorParams struct { - ShardID string - TaskListsLock *sync.RWMutex - TaskLists map[Identifier]Manager - ReportTTL time.Duration - TimeSource clock.TimeSource + ShardID string + TaskListsRegistry ManagerRegistry + ReportTTL time.Duration + TimeSource clock.TimeSource } type shardProcessorImpl struct { - shardID string - taskListsLock *sync.RWMutex // locks mutation of taskLists - taskLists map[Identifier]Manager // Convert to LRU cache - Status atomic.Int32 - reportLock sync.RWMutex - shardReport executorclient.ShardReport - reportTime time.Time - reportTTL time.Duration - timeSource clock.TimeSource + shardID string + taskListsRegistry ManagerRegistry + Status atomic.Int32 + reportLock sync.RWMutex + shardReport executorclient.ShardReport + reportTime time.Time + reportTTL time.Duration + timeSource clock.TimeSource } func NewShardProcessor(params ShardProcessorParams) (ShardProcessor, error) { @@ -38,13 +36,12 @@ func NewShardProcessor(params ShardProcessorParams) (ShardProcessor, error) { return nil, err } shardprocessor := &shardProcessorImpl{ - shardID: params.ShardID, - taskListsLock: params.TaskListsLock, - taskLists: params.TaskLists, - shardReport: executorclient.ShardReport{}, - reportTime: params.TimeSource.Now(), - reportTTL: params.ReportTTL, - timeSource: params.TimeSource, + shardID: params.ShardID, + taskListsRegistry: params.TaskListsRegistry, + shardReport: executorclient.ShardReport{}, + reportTime: params.TimeSource.Now(), + reportTTL: params.ReportTTL, + timeSource: params.TimeSource, } shardprocessor.SetShardStatus(types.ShardStatusREADY) shardprocessor.shardReport = executorclient.ShardReport{ @@ -64,14 +61,7 @@ func (sp *shardProcessorImpl) Start(ctx context.Context) error { // Stop is stopping the tasklist when a shard is not assigned to this executor anymore. func (sp *shardProcessorImpl) Stop() { - sp.taskListsLock.RLock() - var toShutDown []Manager - for _, tlMgr := range sp.taskLists { - if tlMgr.TaskListID().name == sp.shardID { - toShutDown = append(toShutDown, tlMgr) - } - } - sp.taskListsLock.RUnlock() + toShutDown := sp.taskListsRegistry.ManagersByTaskListName(sp.shardID) for _, tlMgr := range toShutDown { tlMgr.Stop() } @@ -97,19 +87,15 @@ func (sp *shardProcessorImpl) SetShardStatus(status types.ShardStatus) { } func (sp *shardProcessorImpl) getShardLoad() float64 { - sp.taskListsLock.RLock() - defer sp.taskListsLock.RUnlock() var load float64 // We assign a shard only based on the task list name // so task lists of different task type (decisions/activities), of different kind (normal, sticky, ephemeral) or partitions // will be assigned all to the same matching instance (executor) // we need to sum the rps for each of the tasklist to calculate the load. - for _, tlMgr := range sp.taskLists { - if tlMgr.TaskListID().name == sp.shardID { - qps := tlMgr.QueriesPerSecond() - load = load + qps - } + for _, tlMgr := range sp.taskListsRegistry.ManagersByTaskListName(sp.shardID) { + qps := tlMgr.QueriesPerSecond() + load = load + qps } return load } @@ -118,11 +104,8 @@ func validateSPParams(params ShardProcessorParams) error { if params.ShardID == "" { return errors.New("ShardID must be specified") } - if params.TaskListsLock == nil { - return errors.New("TaskListsLock must be specified") - } - if params.TaskLists == nil { - return errors.New("TaskLists must be specified") + if params.TaskListsRegistry == nil { + return errors.New("TaskListsRegistry must be specified") } if params.TimeSource == nil { return errors.New("TimeSource must be specified") diff --git a/service/matching/tasklist/shard_processor_factory.go b/service/matching/tasklist/shard_processor_factory.go index 38ebbd946a3..9bf02455b07 100644 --- a/service/matching/tasklist/shard_processor_factory.go +++ b/service/matching/tasklist/shard_processor_factory.go @@ -1,7 +1,6 @@ package tasklist import ( - "sync" "time" "github.com/uber/cadence/common/clock" @@ -9,20 +8,18 @@ import ( // ShardProcessorFactory is a generic factory for creating ShardProcessor instances. type ShardProcessorFactory struct { - TaskListsLock *sync.RWMutex // locks mutation of taskLists - TaskLists map[Identifier]Manager // Convert to LRU cache - ReportTTL time.Duration - TimeSource clock.TimeSource + TaskListsRegistry ManagerRegistry + ReportTTL time.Duration + TimeSource clock.TimeSource } func (spf ShardProcessorFactory) NewShardProcessor(shardID string) (ShardProcessor, error) { params := ShardProcessorParams{ - ShardID: shardID, - TaskListsLock: spf.TaskListsLock, - TaskLists: spf.TaskLists, - ReportTTL: spf.ReportTTL, - TimeSource: spf.TimeSource, + ShardID: shardID, + TaskListsRegistry: spf.TaskListsRegistry, + ReportTTL: spf.ReportTTL, + TimeSource: spf.TimeSource, } return NewShardProcessor(params) } diff --git a/service/matching/tasklist/shard_processor_test.go b/service/matching/tasklist/shard_processor_test.go index 329f5fdcdbd..5a4a16e5b3b 100644 --- a/service/matching/tasklist/shard_processor_test.go +++ b/service/matching/tasklist/shard_processor_test.go @@ -1,7 +1,6 @@ package tasklist import ( - "sync" "testing" "time" @@ -14,98 +13,65 @@ import ( "github.com/uber/cadence/common/types" ) -func paramsForTaskListManager(taskListID *Identifier) ShardProcessorParams { - var mutex sync.RWMutex - taskList := make(map[Identifier]Manager) - params := ShardProcessorParams{ - ShardID: taskListID.GetName(), - TaskListsLock: &mutex, - TaskLists: taskList, - ReportTTL: 1 * time.Millisecond, - TimeSource: clock.NewRealTimeSource(), +func mustNewIdentifier(domainID, name string, taskType int) *Identifier { + id, err := NewIdentifier(domainID, name, taskType) + if err != nil { + panic(err) } - return params + return id } -func paramsForTaskListManagerWithStopCallback(t *testing.T, taskListID *Identifier) ShardProcessorParams { - params := paramsForTaskListManager(taskListID) - mockCtrl := gomock.NewController(t) - mockManager := NewMockManager(mockCtrl) - params.TaskLists[*taskListID] = mockManager - mockManager.EXPECT().TaskListID().Return( - taskListID).Times(1) - mockManager.EXPECT().Stop().Do( - func() { - delete(params.TaskLists, *taskListID) - }, - ) - return params +var testIdentifier = mustNewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) + +type shardProcessorTestData struct { + mockRegistry *MockManagerRegistry + shardProcessor ShardProcessor } -func TestNewShardProcessor(t *testing.T) { - t.Run("NewShardProcessor fails with empty params", func(t *testing.T) { - params := ShardProcessorParams{} - sp, err := NewShardProcessor(params) - require.Nil(t, sp) - require.Error(t, err) - }) +func newShardProcessorTestData(t *testing.T, taskListID *Identifier) shardProcessorTestData { + ctrl := gomock.NewController(t) - t.Run("NewShardProcessor success", func(t *testing.T) { - tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) - require.NoError(t, err) - params := paramsForTaskListManager(tlID) - sp, err := NewShardProcessor(params) - require.NoError(t, err) - require.NotNil(t, sp) - }) -} + mockRegistry := NewMockManagerRegistry(ctrl) + mockRegistry.EXPECT().ManagersByTaskListName(taskListID.GetName()).Return([]Manager{}).AnyTimes() -func TestStop(t *testing.T) { - t.Run("Stop ShardProcessor", func(t *testing.T) { - tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) - require.NoError(t, err) - params := paramsForTaskListManagerWithStopCallback(t, tlID) + params := ShardProcessorParams{ + ShardID: taskListID.GetName(), + TaskListsRegistry: mockRegistry, + ReportTTL: 1 * time.Millisecond, + TimeSource: clock.NewRealTimeSource(), + } - sp, err := NewShardProcessor(params) - require.NoError(t, err) - params.TaskListsLock.RLock() - require.Equal(t, 1, len(params.TaskLists)) - params.TaskListsLock.RUnlock() + shardProcessor, err := NewShardProcessor(params) + require.NoError(t, err) + return shardProcessorTestData{ + mockRegistry: mockRegistry, + shardProcessor: shardProcessor, + } +} - sp.Stop() - params.TaskListsLock.RLock() - require.Equal(t, 0, len(params.TaskLists)) - params.TaskListsLock.RUnlock() - }) +func TestNewShardProcessorFailsWithEmptyParams(t *testing.T) { + params := ShardProcessorParams{} + sp, err := NewShardProcessor(params) + require.Nil(t, sp) + require.Error(t, err) } func TestGetShardReport(t *testing.T) { - t.Run("GetShardReport success", func(t *testing.T) { - tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) - require.NoError(t, err) - params := paramsForTaskListManager(tlID) - sp, err := NewShardProcessor(params) - require.NoError(t, err) - shardReport := sp.GetShardReport() - require.NotNil(t, shardReport) - require.Equal(t, float64(0), shardReport.ShardLoad) - require.Equal(t, types.ShardStatusREADY, shardReport.Status) - }) + td := newShardProcessorTestData(t, testIdentifier) + + shardReport := td.shardProcessor.GetShardReport() + require.NotNil(t, shardReport) + require.Equal(t, float64(0), shardReport.ShardLoad) + require.Equal(t, types.ShardStatusREADY, shardReport.Status) } func TestSetShardStatus(t *testing.T) { defer goleak.VerifyNone(t) + td := newShardProcessorTestData(t, testIdentifier) - t.Run("SetShardStatus success", func(t *testing.T) { - tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) - require.NoError(t, err) - params := paramsForTaskListManager(tlID) - sp, err := NewShardProcessor(params) - require.NoError(t, err) - sp.SetShardStatus(types.ShardStatusREADY) - shardReport := sp.GetShardReport() - require.NotNil(t, shardReport) - require.Equal(t, float64(0), shardReport.ShardLoad) - require.Equal(t, types.ShardStatusREADY, shardReport.Status) - }) + td.shardProcessor.SetShardStatus(types.ShardStatusREADY) + shardReport := td.shardProcessor.GetShardReport() + require.NotNil(t, shardReport) + require.Equal(t, float64(0), shardReport.ShardLoad) + require.Equal(t, types.ShardStatusREADY, shardReport.Status) } diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index abdd0d2876d..46eeb3709b9 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -318,7 +318,7 @@ func (c *taskListManagerImpl) Stop() { } // Notify parent registry to unregister this manager - c.registry.UnregisterManager(c) + c.registry.Unregister(c) if c.adaptiveScaler != nil { c.adaptiveScaler.Stop() diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go index 6f7ff70a53e..a7c4ef22865 100644 --- a/service/matching/tasklist/task_list_manager_test.go +++ b/service/matching/tasklist/task_list_manager_test.go @@ -91,7 +91,7 @@ func setupMocksForTaskListManager(t *testing.T, taskListID *Identifier, taskList config := config.NewConfig(dynamicconfig.NewCollection(dynamicClient, logger), "hostname", commonConfig.RPC{}, getIsolationgroupsHelper) mockHistoryService := history.NewMockClient(ctrl) mockRegistry := NewMockManagerRegistry(ctrl) - mockRegistry.EXPECT().UnregisterManager(gomock.Any()).AnyTimes() + mockRegistry.EXPECT().Unregister(gomock.Any()).AnyTimes() params := ManagerParams{ DomainCache: deps.mockDomainCache, Logger: logger, @@ -239,7 +239,7 @@ func createTestTaskListManagerWithConfig(t *testing.T, logger log.Logger, contro mockMatchingClient.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockHistoryService := history.NewMockClient(controller) mockRegistry := NewMockManagerRegistry(controller) - mockRegistry.EXPECT().UnregisterManager(gomock.Any()).AnyTimes() + mockRegistry.EXPECT().Unregister(gomock.Any()).AnyTimes() tl := "tl" dID := "domain" tlID, err := NewIdentifier(dID, tl, persistence.TaskListTypeActivity) @@ -284,14 +284,14 @@ func TestTaskListManagerRegistryNotification(t *testing.T) { // Replace the registry with our mock tlm.registry = mockRegistry - // Expect UnregisterManager to be called exactly once with the manager instance - mockRegistry.EXPECT().UnregisterManager(tlm).Times(1) + // Expect Unregister to be called exactly once with the manager instance + mockRegistry.EXPECT().Unregister(tlm).Times(1) // Start the manager err := tlm.Start(context.Background()) require.NoError(t, err) - // Stop should call UnregisterManager + // Stop should call Unregister tlm.Stop() // Verify the manager stopped require.Equal(t, int32(1), tlm.stopped) @@ -965,7 +965,7 @@ func TestTaskListManagerGetTaskBatch(t *testing.T) { cfg.RangeSize = rangeSize cfg.ReadRangeSize = dynamicproperties.GetIntPropertyFn(rangeSize / 2) mockRegistry := NewMockManagerRegistry(controller) - mockRegistry.EXPECT().UnregisterManager(gomock.Any()).AnyTimes() + mockRegistry.EXPECT().Unregister(gomock.Any()).AnyTimes() params := ManagerParams{ DomainCache: mockDomainCache, Logger: logger, @@ -1099,7 +1099,7 @@ func TestTaskListReaderPumpAdvancesAckLevelAfterEmptyReads(t *testing.T) { cfg.ReadRangeSize = dynamicproperties.GetIntPropertyFn(rangeSize / 2) mockRegistry := NewMockManagerRegistry(controller) - mockRegistry.EXPECT().UnregisterManager(gomock.Any()).AnyTimes() + mockRegistry.EXPECT().Unregister(gomock.Any()).AnyTimes() params := ManagerParams{ DomainCache: mockDomainCache, Logger: logger, @@ -1249,7 +1249,7 @@ func TestTaskExpiryAndCompletion(t *testing.T) { // on enqueuing a task to task buffer cfg.IdleTasklistCheckInterval = dynamicproperties.GetDurationPropertyFnFilteredByTaskListInfo(20 * time.Millisecond) mockRegistry := NewMockManagerRegistry(controller) - mockRegistry.EXPECT().UnregisterManager(gomock.Any()).AnyTimes() + mockRegistry.EXPECT().Unregister(gomock.Any()).AnyTimes() params := ManagerParams{ DomainCache: mockDomainCache, Logger: logger, diff --git a/service/matching/tasklist/task_list_registry.go b/service/matching/tasklist/task_list_registry.go new file mode 100644 index 00000000000..266c19cb2e8 --- /dev/null +++ b/service/matching/tasklist/task_list_registry.go @@ -0,0 +1,98 @@ +package tasklist + +import ( + "sync" + + "github.com/uber/cadence/common/metrics" +) + +type taskListRegistryImpl struct { + sync.RWMutex + taskLists map[Identifier]Manager + metricsClient metrics.Client +} + +func NewManagerRegistry(metricsClient metrics.Client) ManagerRegistry { + return &taskListRegistryImpl{ + taskLists: make(map[Identifier]Manager), + metricsClient: metricsClient, + } +} + +func (r *taskListRegistryImpl) Register(id Identifier, mgr Manager) { + r.Lock() + defer r.Unlock() + + // we can override the manager for the same identifier if it is already registered + // this case should be handled by the caller + r.taskLists[id] = mgr + r.updateMetricsLocked() +} + +func (r *taskListRegistryImpl) Unregister(mgr Manager) bool { + id := mgr.TaskListID() + r.Lock() + defer r.Unlock() + + // we need to make sure we still hold the given `mgr` or we already replaced with a new one. + currentTlMgr, ok := r.taskLists[*id] + if ok && currentTlMgr == mgr { + delete(r.taskLists, *id) + r.updateMetricsLocked() + return true + } + + return false +} + +func (r *taskListRegistryImpl) ManagersByDomainID(domainID string) []Manager { + r.RLock() + defer r.RUnlock() + + var res []Manager + for tl, tlm := range r.taskLists { + if tl.GetDomainID() == domainID { + res = append(res, tlm) + } + } + return res +} + +func (r *taskListRegistryImpl) ManagersByTaskListName(name string) []Manager { + r.RLock() + defer r.RUnlock() + + var res []Manager + for _, tlm := range r.taskLists { + if tlm.TaskListID().GetName() == name { + res = append(res, tlm) + } + } + return res +} + +func (r *taskListRegistryImpl) ManagerByTaskListIdentifier(id Identifier) (Manager, bool) { + r.RLock() + defer r.RUnlock() + + tlMgr, ok := r.taskLists[id] + return tlMgr, ok +} + +func (r *taskListRegistryImpl) AllManagers() []Manager { + r.RLock() + defer r.RUnlock() + + res := make([]Manager, 0, len(r.taskLists)) + for _, tlMgr := range r.taskLists { + res = append(res, tlMgr) + } + return res +} + +func (r *taskListRegistryImpl) updateMetricsLocked() { + r.metricsClient.Scope(metrics.MatchingTaskListMgrScope).UpdateGauge( + metrics.TaskListManagersGauge, + float64(len(r.taskLists)), + ) +} diff --git a/service/matching/tasklist/task_list_registry_test.go b/service/matching/tasklist/task_list_registry_test.go new file mode 100644 index 00000000000..9b9349bacaf --- /dev/null +++ b/service/matching/tasklist/task_list_registry_test.go @@ -0,0 +1,105 @@ +package tasklist + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/uber/cadence/common/metrics" + metricsmocks "github.com/uber/cadence/common/metrics/mocks" + "github.com/uber/cadence/common/persistence" +) + +func mustNewIdentifierForTest(t *testing.T, domainID, taskListName string) *Identifier { + t.Helper() + id, err := NewIdentifier(domainID, taskListName, persistence.TaskListTypeDecision) + require.NoError(t, err) + return id +} + +func newMockManagerWithID(t *testing.T, ctrl *gomock.Controller, id *Identifier) *MockManager { + t.Helper() + mgr := NewMockManager(ctrl) + mgr.EXPECT().TaskListID().Return(id).AnyTimes() + return mgr +} + +func TestTaskListRegistry_RegisterLookupAndUnregister(t *testing.T) { + ctrl := gomock.NewController(t) + metricsClient := metricsmocks.Client{} + metricsScope := metricsmocks.Scope{} + metricsClient.On("Scope", metrics.MatchingTaskListMgrScope).Return(&metricsScope) + registry := NewManagerRegistry(&metricsClient) + + id := mustNewIdentifierForTest(t, "domain-a", "task-list-a") + + initialMgr := newMockManagerWithID(t, ctrl, id) + updatedMgr := newMockManagerWithID(t, ctrl, id) + + metricsScope.On("UpdateGauge", metrics.TaskListManagersGauge, float64(1)).Once() + registry.Register(*id, initialMgr) + got, ok := registry.ManagerByTaskListIdentifier(*id) + require.True(t, ok) + assert.Equal(t, initialMgr, got) + + // Re-register with the same identifier should replace the manager. + metricsScope.On("UpdateGauge", metrics.TaskListManagersGauge, float64(1)).Once() + registry.Register(*id, updatedMgr) + got, ok = registry.ManagerByTaskListIdentifier(*id) + require.True(t, ok) + assert.Equal(t, updatedMgr, got) + + // Unregister should not remove a replaced/stale manager. + assert.False(t, registry.Unregister(initialMgr)) + got, ok = registry.ManagerByTaskListIdentifier(*id) + require.True(t, ok) + assert.Equal(t, updatedMgr, got) + + // Unregistering the current manager should remove the entry. + metricsScope.On("UpdateGauge", metrics.TaskListManagersGauge, float64(0)).Once() + assert.True(t, registry.Unregister(updatedMgr)) + _, ok = registry.ManagerByTaskListIdentifier(*id) + assert.False(t, ok) + + metricsClient.AssertExpectations(t) + metricsScope.AssertExpectations(t) +} + +func TestTaskListRegistry_Filters(t *testing.T) { + ctrl := gomock.NewController(t) + registry := NewManagerRegistry(metrics.NewNoopMetricsClient()) + + domainA1 := mustNewIdentifierForTest(t, "domain-a", "shared-name") + domainA2 := mustNewIdentifierForTest(t, "domain-a", "other-name") + domainB1 := mustNewIdentifierForTest(t, "domain-b", "shared-name") + + mgrA1 := newMockManagerWithID(t, ctrl, domainA1) + mgrA2 := newMockManagerWithID(t, ctrl, domainA2) + mgrB1 := newMockManagerWithID(t, ctrl, domainB1) + + registry.Register(*domainA1, mgrA1) + registry.Register(*domainA2, mgrA2) + registry.Register(*domainB1, mgrB1) + + t.Run("all managers", func(t *testing.T) { + assert.ElementsMatch(t, []Manager{mgrA1, mgrA2, mgrB1}, registry.AllManagers()) + }) + + t.Run("managers by domain", func(t *testing.T) { + assert.ElementsMatch(t, []Manager{mgrA1, mgrA2}, registry.ManagersByDomainID("domain-a")) + assert.ElementsMatch(t, []Manager{mgrB1}, registry.ManagersByDomainID("domain-b")) + assert.Empty(t, registry.ManagersByDomainID("missing-domain")) + }) + + t.Run("managers by task list name", func(t *testing.T) { + assert.ElementsMatch( + t, + []Manager{mgrA1, mgrB1}, + registry.ManagersByTaskListName("shared-name"), + ) + assert.ElementsMatch(t, []Manager{mgrA2}, registry.ManagersByTaskListName("other-name")) + assert.Empty(t, registry.ManagersByTaskListName("missing-name")) + }) +}