diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go
index e0e6502de4e..6db2159e9df 100644
--- a/service/matching/handler/engine.go
+++ b/service/matching/handler/engine.go
@@ -54,6 +54,7 @@ import (
 	"github.com/uber/cadence/service/matching/config"
 	"github.com/uber/cadence/service/matching/event"
 	"github.com/uber/cadence/service/matching/tasklist"
+	"github.com/uber/cadence/service/sharddistributor/client/executorclient"
 )
 
 // If sticky poller is not seem in last 10s, we treat it as sticky worker unavailable
@@ -88,8 +89,10 @@ type (
 		tokenSerializer      common.TaskTokenSerializer
 		logger               log.Logger
 		metricsClient        metrics.Client
-		taskListsLock        sync.RWMutex                             // locks mutation of taskLists
-		taskLists            map[tasklist.Identifier]tasklist.Manager // Convert to LRU cache
+		taskListsLock        sync.RWMutex                                    // locks mutation of taskLists
+		taskLists            map[tasklist.Identifier]tasklist.ShardProcessor // Convert to LRU cache
+		executor             executorclient.Executor[tasklist.ShardProcessor]
+		taskListsFactory     *tasklist.ShardProcessorFactory
 		config               *config.Config
 		lockableQueryTaskMap lockableQueryTaskMap
 		domainCache          cache.DomainCache
@@ -135,7 +138,6 @@ func NewEngine(
 	isolationState isolationgroup.State,
 	timeSource clock.TimeSource,
 ) Engine {
-
 	e := &matchingEngineImpl{
 		shutdown:           make(chan struct{}),
 		shutdownCompletion: &sync.WaitGroup{},
@@ -143,7 +145,7 @@ func NewEngine(
 		clusterMetadata:    clusterMetadata,
 		historyService:     historyService,
 		tokenSerializer:    common.NewJSONTaskTokenSerializer(),
-		taskLists:          make(map[tasklist.Identifier]tasklist.Manager),
+		taskLists:          make(map[tasklist.Identifier]tasklist.ShardProcessor),
 		logger:             logger.WithTags(tag.ComponentMatchingEngine),
 		metricsClient:      metricsClient,
 		matchingClient:     matchingClient,
@@ -156,6 +158,7 @@ func NewEngine(
 		timeSource:         timeSource,
 	}
 
+	e.setupTaskListFactory()
 	e.shutdownCompletion.Add(1)
 	go e.runMembershipChangeLoop()
 
@@ -176,10 +179,27 @@ func (e *matchingEngineImpl) Stop() {
 	e.shutdownCompletion.Wait()
 }
 
-func (e *matchingEngineImpl) getTaskLists(maxCount int) []tasklist.Manager {
+func (e *matchingEngineImpl) setupTaskListFactory() {
+	taskListFactory := &tasklist.ShardProcessorFactory{
+		DomainCache:     e.domainCache,
+		Logger:          e.logger,
+		MetricsClient:   e.metricsClient,
+		TaskManager:     e.taskManager,
+		ClusterMetadata: e.clusterMetadata,
+		IsolationState:  e.isolationState,
+		MatchingClient:  e.matchingClient,
+		CloseCallback:   e.removeTaskListManager,
+		Cfg:             e.config,
+		TimeSource:      e.timeSource,
+		CreateTime:      e.timeSource.Now(),
+		HistoryService:  e.historyService}
+	e.taskListsFactory = taskListFactory
+}
+
+func (e *matchingEngineImpl) getTaskLists(maxCount int) []tasklist.ShardProcessor {
 	e.taskListsLock.RLock()
 	defer e.taskListsLock.RUnlock()
-	lists := make([]tasklist.Manager, 0, len(e.taskLists))
+	lists := make([]tasklist.ShardProcessor, 0, len(e.taskLists))
 	count := 0
 	for _, tlMgr := range e.taskLists {
 		lists = append(lists, tlMgr)
@@ -202,7 +222,7 @@ func (e *matchingEngineImpl) String() string {
 
 // Returns taskListManager for a task list. If not already cached gets new range from DB and
 // if successful creates one.
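Note: the hunks above route per-task-list construction through the new tasklist.ShardProcessorFactory, which the engine builds once in setupTaskListFactory and reuses when it needs a processor for a task list. A minimal lifecycle sketch, assuming the surrounding Cadence packages; the helper name and the identifier values here are made up for illustration, and error handling is abbreviated:

package exampleusage

import (
	"context"

	"github.com/uber/cadence/common/persistence"
	"github.com/uber/cadence/common/types"
	"github.com/uber/cadence/service/matching/tasklist"
)

// exampleProcessorLifecycle (hypothetical, not part of this change) mirrors what the
// engine now does through the factory: build a ShardProcessor for a task-list
// identifier, start it with a context, and stop it on teardown.
func exampleProcessorLifecycle(ctx context.Context, factory *tasklist.ShardProcessorFactory) error {
	id, err := tasklist.NewIdentifier("some-domain-id", "some-tasklist", persistence.TaskListTypeDecision)
	if err != nil {
		return err
	}
	processor, err := factory.NewShardProcessorWithTaskListIdentifier(id, types.TaskListKindNormal)
	if err != nil {
		return err
	}
	// Start now takes a context and returns an error instead of being fire-and-forget.
	if err := processor.Start(ctx); err != nil {
		return err
	}
	defer processor.Stop()
	// ... add, poll, or dispatch tasks via the processor ...
	return nil
}

The engine's getTaskListManager in the next hunk follows the same shape, propagating any Start error to its caller.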
-func (e *matchingEngineImpl) getTaskListManager(taskList *tasklist.Identifier, taskListKind types.TaskListKind) (tasklist.Manager, error) { +func (e *matchingEngineImpl) getTaskListManager(taskList *tasklist.Identifier, taskListKind types.TaskListKind) (tasklist.ShardProcessor, error) { // The first check is an optimization so almost all requests will have a task list manager // and return avoiding the write lock e.taskListsLock.RLock() @@ -232,22 +252,7 @@ func (e *matchingEngineImpl) getTaskListManager(taskList *tasklist.Identifier, t ) logger.Info("Task list manager state changed", tag.LifeCycleStarting) - params := tasklist.ManagerParams{ - DomainCache: e.domainCache, - Logger: e.logger, - MetricsClient: e.metricsClient, - TaskManager: e.taskManager, - ClusterMetadata: e.clusterMetadata, - IsolationState: e.isolationState, - MatchingClient: e.matchingClient, - CloseCallback: e.removeTaskListManager, - TaskList: taskList, - TaskListKind: taskListKind, - Cfg: e.config, - TimeSource: e.timeSource, - CreateTime: e.timeSource.Now(), - HistoryService: e.historyService} - mgr, err := tasklist.NewManager(params) + mgr, err := e.taskListsFactory.NewShardProcessorWithTaskListIdentifier(taskList, taskListKind) if err != nil { e.taskListsLock.Unlock() logger.Info("Task list manager state changed", tag.LifeCycleStartFailed, tag.Error(err)) @@ -260,7 +265,7 @@ func (e *matchingEngineImpl) getTaskListManager(taskList *tasklist.Identifier, t float64(len(e.taskLists)), ) e.taskListsLock.Unlock() - err = mgr.Start() + err = mgr.Start(context.Background()) if err != nil { logger.Info("Task list manager state changed", tag.LifeCycleStartFailed, tag.Error(err)) return nil, err @@ -298,18 +303,19 @@ func (e *matchingEngineImpl) getTaskListByDomainLocked(domainID string, taskList } // For use in tests -func (e *matchingEngineImpl) updateTaskList(taskList *tasklist.Identifier, mgr tasklist.Manager) { +func (e *matchingEngineImpl) updateTaskList(taskList *tasklist.Identifier, mgr tasklist.ShardProcessor) { e.taskListsLock.Lock() defer e.taskListsLock.Unlock() e.taskLists[*taskList] = mgr } -func (e *matchingEngineImpl) removeTaskListManager(tlMgr tasklist.Manager) { +func (e *matchingEngineImpl) removeTaskListManager(tlMgr tasklist.ShardProcessor) { id := tlMgr.TaskListID() e.taskListsLock.Lock() defer e.taskListsLock.Unlock() + currentTlMgr, ok := e.taskLists[*id] - if ok && tlMgr == currentTlMgr { + if ok && currentTlMgr.String() == tlMgr.String() { delete(e.taskLists, *id) } diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go index 8517012272d..164ee238d7f 100644 --- a/service/matching/handler/engine_integration_test.go +++ b/service/matching/handler/engine_integration_test.go @@ -681,6 +681,7 @@ func (s *matchingEngineSuite) SyncMatchTasks(taskType int, enableIsolation bool) // So we can get snapshots scope := tally.NewTestScope("test", nil) s.matchingEngine.metricsClient = metrics.NewClient(scope, metrics.Matching, metrics.HistogramMigration{}) + s.matchingEngine.taskListsFactory.MetricsClient = metrics.NewClient(scope, metrics.Matching, metrics.HistogramMigration{}) testParam := newTestParam(s.T(), taskType) s.taskManager.SetRangeID(testParam.TaskListID, initialRangeID) @@ -840,6 +841,7 @@ func (s *matchingEngineSuite) ConcurrentAddAndPollTasks(taskType int, workerCoun } scope := tally.NewTestScope("test", nil) s.matchingEngine.metricsClient = metrics.NewClient(scope, metrics.Matching, metrics.HistogramMigration{}) + 
s.matchingEngine.taskListsFactory.MetricsClient = metrics.NewClient(scope, metrics.Matching, metrics.HistogramMigration{}) const initialRangeID = 0 const rangeSize = 3 diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go index 4f1edd8edeb..d46d24fffab 100644 --- a/service/matching/handler/engine_test.go +++ b/service/matching/handler/engine_test.go @@ -50,21 +50,21 @@ import ( func TestGetTaskListsByDomain(t *testing.T) { testCases := []struct { name string - mockSetup func(*cache.MockDomainCache, map[tasklist.Identifier]*tasklist.MockManager, map[tasklist.Identifier]*tasklist.MockManager) + mockSetup func(*cache.MockDomainCache, map[tasklist.Identifier]*tasklist.MockShardProcessor, map[tasklist.Identifier]*tasklist.MockShardProcessor) returnAllKinds bool wantErr bool want *types.GetTaskListsByDomainResponse }{ { name: "domain cache error", - mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockManager, mockStickyManagers map[tasklist.Identifier]*tasklist.MockManager) { + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockShardProcessor, mockStickyManagers map[tasklist.Identifier]*tasklist.MockShardProcessor) { mockDomainCache.EXPECT().GetDomainID("test-domain").Return("", errors.New("cache failure")) }, wantErr: true, }, { name: "success", - mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockManager, mockStickyManagers map[tasklist.Identifier]*tasklist.MockManager) { + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockShardProcessor, mockStickyManagers map[tasklist.Identifier]*tasklist.MockShardProcessor) { mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) for id, mockManager := range mockTaskListManagers { if id.GetDomainID() == "test-domain-id" { @@ -109,7 +109,7 @@ func TestGetTaskListsByDomain(t *testing.T) { { name: "success - all kinds", returnAllKinds: true, - mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockManager, mockStickyManagers map[tasklist.Identifier]*tasklist.MockManager) { + mockSetup: func(mockDomainCache *cache.MockDomainCache, mockTaskListManagers map[tasklist.Identifier]*tasklist.MockShardProcessor, mockStickyManagers map[tasklist.Identifier]*tasklist.MockShardProcessor) { mockDomainCache.EXPECT().GetDomainID("test-domain").Return("test-domain-id", nil) for id, mockManager := range mockTaskListManagers { if id.GetDomainID() == "test-domain-id" { @@ -175,25 +175,25 @@ func TestGetTaskListsByDomain(t *testing.T) { require.NoError(t, err) otherDomainTasklistID, err := tasklist.NewIdentifier("other-domain-id", "other0", 0) require.NoError(t, err) - mockDecisionTaskListManager := tasklist.NewMockManager(mockCtrl) - mockActivityTaskListManager := tasklist.NewMockManager(mockCtrl) - mockOtherDomainTaskListManager := tasklist.NewMockManager(mockCtrl) - mockTaskListManagers := map[tasklist.Identifier]*tasklist.MockManager{ + mockDecisionTaskListManager := tasklist.NewMockShardProcessor(mockCtrl) + mockActivityTaskListManager := tasklist.NewMockShardProcessor(mockCtrl) + mockOtherDomainTaskListManager := tasklist.NewMockShardProcessor(mockCtrl) + mockTaskListManagers := map[tasklist.Identifier]*tasklist.MockShardProcessor{ *decisionTasklistID: mockDecisionTaskListManager, 
*activityTasklistID: mockActivityTaskListManager, *otherDomainTasklistID: mockOtherDomainTaskListManager, } stickyTasklistID, err := tasklist.NewIdentifier("test-domain-id", "sticky0", 0) require.NoError(t, err) - mockStickyManager := tasklist.NewMockManager(mockCtrl) - mockStickyManagers := map[tasklist.Identifier]*tasklist.MockManager{ + mockStickyManager := tasklist.NewMockShardProcessor(mockCtrl) + mockStickyManagers := map[tasklist.Identifier]*tasklist.MockShardProcessor{ *stickyTasklistID: mockStickyManager, } tc.mockSetup(mockDomainCache, mockTaskListManagers, mockStickyManagers) engine := &matchingEngineImpl{ domainCache: mockDomainCache, - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *decisionTasklistID: mockDecisionTaskListManager, *activityTasklistID: mockActivityTaskListManager, *otherDomainTasklistID: mockOtherDomainTaskListManager, @@ -334,7 +334,7 @@ func TestCancelOutstandingPoll(t *testing.T) { testCases := []struct { name string req *types.CancelOutstandingPollRequest - mockSetup func(*tasklist.MockManager) + mockSetup func(processor *tasklist.MockShardProcessor) wantErr bool }{ { @@ -346,7 +346,7 @@ func TestCancelOutstandingPoll(t *testing.T) { }, PollerID: "test-poller-id", }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, wantErr: true, }, @@ -359,7 +359,7 @@ func TestCancelOutstandingPoll(t *testing.T) { }, PollerID: "test-poller-id", }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { mockManager.EXPECT().CancelPoller("test-poller-id") }, wantErr: false, @@ -369,12 +369,12 @@ func TestCancelOutstandingPoll(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + mockManager := tasklist.NewMockShardProcessor(mockCtrl) tc.mockSetup(mockManager) tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) require.NoError(t, err) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tasklistID: mockManager, }, } @@ -445,7 +445,7 @@ func TestQueryWorkflow(t *testing.T) { name string req *types.MatchingQueryWorkflowRequest hCtx *handlerContext - mockSetup func(*tasklist.MockManager, *lockableQueryTaskMap) + mockSetup func(*tasklist.MockShardProcessor, *lockableQueryTaskMap) wantErr bool want *types.MatchingQueryWorkflowResponse }{ @@ -457,7 +457,7 @@ func TestQueryWorkflow(t *testing.T) { Name: "/__cadence_sys/invalid-tasklist-name", }, }, - mockSetup: func(mockManager *tasklist.MockManager, queryResultMap *lockableQueryTaskMap) {}, + mockSetup: func(mockManager *tasklist.MockShardProcessor, queryResultMap *lockableQueryTaskMap) {}, wantErr: true, }, { @@ -469,7 +469,7 @@ func TestQueryWorkflow(t *testing.T) { Kind: types.TaskListKindSticky.Ptr(), }, }, - mockSetup: func(mockManager *tasklist.MockManager, queryResultMap *lockableQueryTaskMap) { + mockSetup: func(mockManager *tasklist.MockShardProcessor, queryResultMap *lockableQueryTaskMap) { mockManager.EXPECT().HasPollerAfter(gomock.Any()).Return(false) }, wantErr: true, @@ -485,7 +485,7 @@ func TestQueryWorkflow(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager, queryResultMap *lockableQueryTaskMap) { + mockSetup: 
func(mockManager *tasklist.MockShardProcessor, queryResultMap *lockableQueryTaskMap) { mockManager.EXPECT().DispatchQueryTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("some error")) }, wantErr: true, @@ -503,7 +503,7 @@ func TestQueryWorkflow(t *testing.T) { return context.Background() }(), }, - mockSetup: func(mockManager *tasklist.MockManager, queryResultMap *lockableQueryTaskMap) { + mockSetup: func(mockManager *tasklist.MockShardProcessor, queryResultMap *lockableQueryTaskMap) { mockManager.EXPECT().DispatchQueryTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, taskID string, request *types.MatchingQueryWorkflowRequest) (*types.MatchingQueryWorkflowResponse, error) { queryResChan, ok := queryResultMap.get(taskID) if !ok { @@ -546,11 +546,11 @@ func TestQueryWorkflow(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + mockManager := tasklist.NewMockShardProcessor(mockCtrl) tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 0) require.NoError(t, err) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tasklistID: mockManager, }, timeSource: clock.NewRealTimeSource(), @@ -734,9 +734,9 @@ func TestGetTasklistsNotOwned(t *testing.T) { tl2, _ := tasklist.NewIdentifier("", "tl2", 0) tl3, _ := tasklist.NewIdentifier("", "tl3", 0) - tl1m := tasklist.NewMockManager(ctrl) - tl2m := tasklist.NewMockManager(ctrl) - tl3m := tasklist.NewMockManager(ctrl) + tl1m := tasklist.NewMockShardProcessor(ctrl) + tl2m := tasklist.NewMockShardProcessor(ctrl) + tl3m := tasklist.NewMockShardProcessor(ctrl) resolver.EXPECT().Lookup(service.Matching, tl1.GetName()).Return(membership.NewDetailedHostInfo("", "host123", nil), nil) resolver.EXPECT().Lookup(service.Matching, tl2.GetName()).Return(membership.NewDetailedHostInfo("", "host456", nil), nil) @@ -746,7 +746,7 @@ func TestGetTasklistsNotOwned(t *testing.T) { shutdown: make(chan struct{}), membershipResolver: resolver, taskListsLock: sync.RWMutex{}, - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tl1: tl1m, *tl2: tl2m, *tl3: tl3m, @@ -774,9 +774,9 @@ func TestShutDownTasklistsNotOwned(t *testing.T) { tl2, _ := tasklist.NewIdentifier("", "tl2", 0) tl3, _ := tasklist.NewIdentifier("", "tl3", 0) - tl1m := tasklist.NewMockManager(ctrl) - tl2m := tasklist.NewMockManager(ctrl) - tl3m := tasklist.NewMockManager(ctrl) + tl1m := tasklist.NewMockShardProcessor(ctrl) + tl2m := tasklist.NewMockShardProcessor(ctrl) + tl3m := tasklist.NewMockShardProcessor(ctrl) resolver.EXPECT().Lookup(service.Matching, tl1.GetName()).Return(membership.NewDetailedHostInfo("", "host123", nil), nil) resolver.EXPECT().Lookup(service.Matching, tl2.GetName()).Return(membership.NewDetailedHostInfo("", "host456", nil), nil) @@ -786,7 +786,7 @@ func TestShutDownTasklistsNotOwned(t *testing.T) { shutdown: make(chan struct{}), membershipResolver: resolver, taskListsLock: sync.RWMutex{}, - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tl1: tl1m, *tl2: tl2m, *tl3: tl3m, @@ -821,7 +821,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { req *types.MatchingUpdateTaskListPartitionConfigRequest enableAdaptiveScaler bool hCtx *handlerContext - mockSetup 
func(*tasklist.MockManager) + mockSetup func(*tasklist.MockShardProcessor) expectError bool expectedError string }{ @@ -846,7 +846,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { mockManager.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ Version: 1, ReadPartitions: map[int]*types.TaskListPartition{ @@ -880,7 +880,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { mockManager.EXPECT().UpdateTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ Version: 1, ReadPartitions: map[int]*types.TaskListPartition{ @@ -915,7 +915,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "Only root partition's partition config can be updated.", @@ -941,7 +941,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "invalid partitioned task list name /__cadence_sys/test-tasklist", @@ -958,7 +958,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "Task list partition config is not set in the request.", @@ -976,7 +976,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "Only normal tasklist's partition config can be updated.", @@ -995,7 +995,7 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "Manual update is not allowed because adaptive scaler is enabled.", @@ -1007,12 +1007,12 @@ func TestUpdateTaskListPartitionConfig(t *testing.T) { mockCtrl := gomock.NewController(t) mockDomainCache := cache.NewMockDomainCache(mockCtrl) mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) - mockManager := tasklist.NewMockManager(mockCtrl) + mockManager := tasklist.NewMockShardProcessor(mockCtrl) tc.mockSetup(mockManager) tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) require.NoError(t, err) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tasklistID: mockManager, }, timeSource: clock.NewRealTimeSource(), @@ -1036,7 +1036,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { name string req *types.MatchingRefreshTaskListPartitionConfigRequest hCtx *handlerContext - mockSetup 
func(*tasklist.MockManager) + mockSetup func(*tasklist.MockShardProcessor) expectError bool expectedError string }{ @@ -1061,7 +1061,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { mockManager.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ Version: 1, ReadPartitions: map[int]*types.TaskListPartition{ @@ -1095,7 +1095,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { mockManager.EXPECT().RefreshTaskListPartitionConfig(gomock.Any(), &types.TaskListPartitionConfig{ Version: 1, ReadPartitions: map[int]*types.TaskListPartition{ @@ -1130,7 +1130,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "invalid partitioned task list name /__cadence_sys/test-tasklist", @@ -1148,7 +1148,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "Only normal tasklist's partition config can be updated.", @@ -1174,7 +1174,7 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { hCtx: &handlerContext{ Context: context.Background(), }, - mockSetup: func(mockManager *tasklist.MockManager) { + mockSetup: func(mockManager *tasklist.MockShardProcessor) { }, expectError: true, expectedError: "PartitionConfig must be nil for root partition.", @@ -1184,14 +1184,14 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - mockManager := tasklist.NewMockManager(mockCtrl) + mockManager := tasklist.NewMockShardProcessor(mockCtrl) tc.mockSetup(mockManager) tasklistID, err := tasklist.NewIdentifier("test-domain-id", "test-tasklist", 1) require.NoError(t, err) tasklistID2, err := tasklist.NewIdentifier("test-domain-id", "/__cadence_sys/test-tasklist/1", 1) require.NoError(t, err) engine := &matchingEngineImpl{ - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tasklistID: mockManager, *tasklistID2: mockManager, }, @@ -1213,13 +1213,13 @@ func Test_domainChangeCallback(t *testing.T) { clusters := []string{"cluster0", "cluster1"} - mockTaskListManagerGlobal1 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl) - mockStickyTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl) - mockStickyTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerLocal1 := tasklist.NewMockManager(mockCtrl) - mockTaskListManagerActiveActive1 := tasklist.NewMockManager(mockCtrl) + mockTaskListManagerGlobal1 := tasklist.NewMockShardProcessor(mockCtrl) + mockTaskListManagerGlobal2 := tasklist.NewMockShardProcessor(mockCtrl) + mockStickyTaskListManagerGlobal2 := tasklist.NewMockShardProcessor(mockCtrl) + 
mockTaskListManagerGlobal3 := tasklist.NewMockShardProcessor(mockCtrl) + mockStickyTaskListManagerGlobal3 := tasklist.NewMockShardProcessor(mockCtrl) + mockTaskListManagerLocal1 := tasklist.NewMockShardProcessor(mockCtrl) + mockTaskListManagerActiveActive1 := tasklist.NewMockShardProcessor(mockCtrl) tlIdentifierDecisionGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeDecision) tlIdentifierActivityGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeActivity) @@ -1239,7 +1239,7 @@ func Test_domainChangeCallback(t *testing.T) { failoverNotificationVersion: 1, config: defaultTestConfig(), logger: log.NewNoop(), - taskLists: map[tasklist.Identifier]tasklist.Manager{ + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{ *tlIdentifierDecisionGlobal1: mockTaskListManagerGlobal1, *tlIdentifierActivityGlobal1: mockTaskListManagerGlobal1, *tlIdentifierDecisionGlobal2: mockTaskListManagerGlobal2, @@ -1370,7 +1370,7 @@ func Test_registerDomainFailoverCallback(t *testing.T) { failoverNotificationVersion: 0, config: defaultTestConfig(), logger: log.NewNoop(), - taskLists: map[tasklist.Identifier]tasklist.Manager{}, + taskLists: map[tasklist.Identifier]tasklist.ShardProcessor{}, } engine.registerDomainFailoverCallback() diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go index b1073378d91..5bf24f88ae6 100644 --- a/service/matching/tasklist/interfaces.go +++ b/service/matching/tasklist/interfaces.go @@ -38,7 +38,7 @@ import ( type ( Manager interface { - Start() error + Start(ctx context.Context) error Stop() // AddTask adds a task to the task list. This method will first attempt a synchronous // match with a poller. When that fails, task will be written to database and later diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go index ba9a554e2c8..a9910089db7 100644 --- a/service/matching/tasklist/interfaces_mock.go +++ b/service/matching/tasklist/interfaces_mock.go @@ -214,17 +214,17 @@ func (mr *MockManagerMockRecorder) ReleaseBlockedPollers() *gomock.Call { } // Start mocks base method. -func (m *MockManager) Start() error { +func (m *MockManager) Start(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start") + ret := m.ctrl.Call(m, "Start", ctx) ret0, _ := ret[0].(error) return ret0 } // Start indicates an expected call of Start. -func (mr *MockManagerMockRecorder) Start() *gomock.Call { +func (mr *MockManagerMockRecorder) Start(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockManager)(nil).Start)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockManager)(nil).Start), ctx) } // Stop mocks base method. @@ -786,17 +786,17 @@ func (mr *MockShardProcessorMockRecorder) SetShardStatus(arg0 any) *gomock.Call } // Start mocks base method. -func (m *MockShardProcessor) Start() error { +func (m *MockShardProcessor) Start(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Start") + ret := m.ctrl.Call(m, "Start", ctx) ret0, _ := ret[0].(error) return ret0 } // Start indicates an expected call of Start. 
-func (mr *MockShardProcessorMockRecorder) Start() *gomock.Call { +func (mr *MockShardProcessorMockRecorder) Start(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockShardProcessor)(nil).Start)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockShardProcessor)(nil).Start), ctx) } // Stop mocks base method. diff --git a/service/matching/tasklist/shard_processor_factory.go b/service/matching/tasklist/shard_processor_factory.go new file mode 100644 index 00000000000..3f535512188 --- /dev/null +++ b/service/matching/tasklist/shard_processor_factory.go @@ -0,0 +1,80 @@ +package tasklist + +import ( + "time" + + "github.com/uber/cadence/client/history" + "github.com/uber/cadence/client/matching" + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/cluster" + "github.com/uber/cadence/common/isolationgroup" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/types" + "github.com/uber/cadence/service/matching/config" +) + +// ShardProcessorFactory is a generic factory for creating ShardProcessor instances. +type ShardProcessorFactory struct { + DomainCache cache.DomainCache + Logger log.Logger + MetricsClient metrics.Client + TaskManager persistence.TaskManager + ClusterMetadata cluster.Metadata + IsolationState isolationgroup.State + MatchingClient matching.Client + CloseCallback func(processor ShardProcessor) + Cfg *config.Config + TimeSource clock.TimeSource + CreateTime time.Time + HistoryService history.Client +} + +func (spf ShardProcessorFactory) NewShardProcessor(shardID string) (ShardProcessor, error) { + name, err := newTaskListName(shardID) + if err != nil { + return nil, err + } + identifier := &Identifier{ + qualifiedTaskListName: name, + } + params := ManagerParams{ + DomainCache: spf.DomainCache, + Logger: spf.Logger, + MetricsClient: spf.MetricsClient, + TaskManager: spf.TaskManager, + ClusterMetadata: spf.ClusterMetadata, + IsolationState: spf.IsolationState, + MatchingClient: spf.MatchingClient, + CloseCallback: spf.CloseCallback, + TaskList: identifier, + TaskListKind: 0, + Cfg: spf.Cfg, + TimeSource: spf.TimeSource, + CreateTime: spf.TimeSource.Now(), + HistoryService: spf.HistoryService, + } + return NewShardProcessor(params) +} + +func (spf ShardProcessorFactory) NewShardProcessorWithTaskListIdentifier(taskListID *Identifier, taskListKind types.TaskListKind) (ShardProcessor, error) { + params := ManagerParams{ + DomainCache: spf.DomainCache, + Logger: spf.Logger, + MetricsClient: spf.MetricsClient, + TaskManager: spf.TaskManager, + ClusterMetadata: spf.ClusterMetadata, + IsolationState: spf.IsolationState, + MatchingClient: spf.MatchingClient, + CloseCallback: spf.CloseCallback, + TaskList: taskListID, + TaskListKind: taskListKind, + Cfg: spf.Cfg, + TimeSource: spf.TimeSource, + CreateTime: spf.TimeSource.Now(), + HistoryService: spf.HistoryService, + } + return NewShardProcessor(params) +} diff --git a/service/matching/tasklist/shard_processor_test.go b/service/matching/tasklist/shard_processor_test.go index 7259543be61..b149beb07f2 100644 --- a/service/matching/tasklist/shard_processor_test.go +++ b/service/matching/tasklist/shard_processor_test.go @@ -46,7 +46,7 @@ func paramsForTaskListManager(t *testing.T, taskListID *Identifier, taskListKind clusterMetadata, deps.mockIsolationState, 
deps.mockMatchingClient, - func(Manager) {}, + func(ShardProcessor) {}, taskListID, taskListKind, cfg, diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index 030acbc22d8..0f0aaafa38b 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -94,7 +94,7 @@ type ( ClusterMetadata cluster.Metadata IsolationState isolationgroup.State MatchingClient matching.Client - CloseCallback func(Manager) + CloseCallback func(ShardProcessor) TaskList *Identifier TaskListKind types.TaskListKind Cfg *config.Config @@ -140,7 +140,7 @@ type ( stopWG sync.WaitGroup stopped int32 stoppedLock sync.RWMutex - closeCallback func(Manager) + closeCallback func(ShardProcessor) throttleRetry *backoff.ThrottleRetry qpsTracker stats.QPSTrackerGroup @@ -257,7 +257,7 @@ func NewManager(p ManagerParams) (Manager, error) { // Starts reading pump for the given task list. // The pump fills up taskBuffer from persistence. -func (c *taskListManagerImpl) Start() error { +func (c *taskListManagerImpl) Start(ctx context.Context) error { defer c.startWG.Done() if !c.taskListID.IsRoot() && c.taskListKind == types.TaskListKindNormal { @@ -316,7 +316,10 @@ func (c *taskListManagerImpl) Stop() { if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) { return } - c.closeCallback(c) + sp := &shardProcessorImpl{ + Manager: c, + } + c.closeCallback(sp) if c.adaptiveScaler != nil { c.adaptiveScaler.Stop() } diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go index 906e34fb384..adc215abf3b 100644 --- a/service/matching/tasklist/task_list_manager_test.go +++ b/service/matching/tasklist/task_list_manager_test.go @@ -96,7 +96,7 @@ func setupMocksForTaskListManager(t *testing.T, taskListID *Identifier, taskList clusterMetadata, deps.mockIsolationState, deps.mockMatchingClient, - func(Manager) {}, + func(ShardProcessor) {}, taskListID, taskListKind, config, @@ -248,7 +248,7 @@ func createTestTaskListManagerWithConfig(t *testing.T, logger log.Logger, contro cluster.GetTestClusterMetadata(true), mockIsolationState, mockMatchingClient, - func(Manager) {}, + func(ShardProcessor) {}, tlID, types.TaskListKindNormal, cfg, @@ -424,7 +424,7 @@ func TestCheckIdleTaskList(t *testing.T) { t.Run("Idle task-list", func(t *testing.T) { ctrl := gomock.NewController(t) tlm := createTestTaskListManagerWithConfig(t, testlogger.New(t), ctrl, cfg, clock.NewRealTimeSource()) - require.NoError(t, tlm.Start()) + require.NoError(t, tlm.Start(context.Background())) require.EqualValues(t, 0, atomic.LoadInt32(&tlm.stopped), "idle check interval had not passed yet") time.Sleep(20 * time.Millisecond) @@ -434,7 +434,7 @@ func TestCheckIdleTaskList(t *testing.T) { t.Run("Active poll-er", func(t *testing.T) { ctrl := gomock.NewController(t) tlm := createTestTaskListManagerWithConfig(t, testlogger.New(t), ctrl, cfg, clock.NewRealTimeSource()) - require.NoError(t, tlm.Start()) + require.NoError(t, tlm.Start(context.Background())) ctx, cancel := context.WithTimeout(context.Background(), time.Second) _, _ = tlm.GetTask(ctx, nil) @@ -467,7 +467,7 @@ func TestCheckIdleTaskList(t *testing.T) { ctrl := gomock.NewController(t) tlm := createTestTaskListManagerWithConfig(t, testlogger.New(t), ctrl, cfg, clock.NewRealTimeSource()) - require.NoError(t, tlm.Start()) + require.NoError(t, tlm.Start(context.Background())) time.Sleep(8 * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 
time.Second) @@ -493,7 +493,7 @@ func TestAddTaskStandby(t *testing.T) { cfg.IdleTasklistCheckInterval = dynamicproperties.GetDurationPropertyFnFilteredByTaskListInfo(10 * time.Millisecond) tlm := createTestTaskListManagerWithConfig(t, logger, controller, cfg, clock.NewMockedTimeSource()) - require.NoError(t, tlm.Start()) + require.NoError(t, tlm.Start(context.Background())) // stop taskWriter so that we can check if there's any call to it // otherwise the task persist process is async and hard to test @@ -842,7 +842,7 @@ func TestTaskWriterShutdown(t *testing.T) { controller := gomock.NewController(t) logger := testlogger.New(t) tlm := createTestTaskListManager(t, logger, controller) - err := tlm.Start() + err := tlm.Start(context.Background()) assert.NoError(t, err) // stop the task writer explicitly @@ -894,7 +894,7 @@ func TestTaskListManagerGetTaskBatch(t *testing.T) { cluster.GetTestClusterMetadata(true), mockIsolationState, matching.NewMockClient(controller), - func(Manager) {}, + func(ShardProcessor) {}, taskListID, types.TaskListKindNormal, cfg, @@ -905,7 +905,7 @@ func TestTaskListManagerGetTaskBatch(t *testing.T) { tlMgr, err := NewManager(params) assert.NoError(t, err) tlm := tlMgr.(*taskListManagerImpl) - err = tlm.Start() + err = tlm.Start(context.Background()) assert.NoError(t, err) // add taskCount tasks @@ -966,7 +966,7 @@ func TestTaskListManagerGetTaskBatch(t *testing.T) { cluster.GetTestClusterMetadata(true), mockIsolationState, matching.NewMockClient(controller), - func(Manager) {}, + func(ShardProcessor) {}, taskListID, types.TaskListKindNormal, cfg, @@ -977,7 +977,7 @@ func TestTaskListManagerGetTaskBatch(t *testing.T) { tlMgr, err = NewManager(newParams) assert.NoError(t, err) tlm = tlMgr.(*taskListManagerImpl) - err = tlm.Start() + err = tlm.Start(context.Background()) assert.NoError(t, err) for i := int64(0); i < rangeSize; i++ { task, err := tlm.GetTask(context.Background(), nil) @@ -1026,7 +1026,7 @@ func TestTaskListReaderPumpAdvancesAckLevelAfterEmptyReads(t *testing.T) { cluster.GetTestClusterMetadata(true), mockIsolationState, matching.NewMockClient(controller), - func(Manager) {}, + func(ShardProcessor) {}, taskListID, types.TaskListKindNormal, cfg, @@ -1043,7 +1043,7 @@ func TestTaskListReaderPumpAdvancesAckLevelAfterEmptyReads(t *testing.T) { tlm.taskWriter.renewLeaseWithRetry() } - err = tlm.Start() // this call will also renew lease + err = tlm.Start(context.Background()) // this call will also renew lease require.NoError(t, err) defer tlm.Stop() @@ -1174,7 +1174,7 @@ func TestTaskExpiryAndCompletion(t *testing.T) { cluster.GetTestClusterMetadata(true), mockIsolationState, matching.NewMockClient(controller), - func(Manager) {}, + func(ShardProcessor) {}, taskListID, types.TaskListKindNormal, cfg, @@ -1185,7 +1185,7 @@ func TestTaskExpiryAndCompletion(t *testing.T) { tlMgr, err := NewManager(params) assert.NoError(t, err) tlm := tlMgr.(*taskListManagerImpl) - err = tlm.Start() + err = tlm.Start(context.Background()) assert.NoError(t, err) for i := int64(0); i < taskCount; i++ { scheduleID := i * 3 @@ -1253,7 +1253,7 @@ func TestTaskListManagerImpl_HasPollerAfter(t *testing.T) { controller := gomock.NewController(t) logger := testlogger.New(t) tlm := createTestTaskListManager(t, logger, controller) - err := tlm.Start() + err := tlm.Start(context.Background()) assert.NoError(t, err) if tc.prepareManager != nil { @@ -1661,7 +1661,7 @@ func TestManagerStart_RootPartition(t *testing.T) { WritePartitions: partitions(2), }, 
 	}).Return(&types.MatchingRefreshTaskListPartitionConfigResponse{}, nil)
-	assert.NoError(t, tlm.Start())
+	assert.NoError(t, tlm.Start(context.Background()))
 	assert.Equal(t, &types.TaskListPartitionConfig{Version: 1, ReadPartitions: partitions(2), WritePartitions: partitions(2)}, tlm.TaskListPartitionConfig())
 	tlm.stopWG.Wait()
 }
@@ -1703,7 +1703,7 @@ func TestManagerStart_NonRootPartition(t *testing.T) {
 			RangeID: 0,
 		},
 	}, nil)
-	assert.NoError(t, tlm.Start())
+	assert.NoError(t, tlm.Start(context.Background()))
 	assert.Equal(t, &types.TaskListPartitionConfig{
 		Version:        1,
 		ReadPartitions: partitions(3),
diff --git a/service/sharddistributor/canary/processor/shardprocessor.go b/service/sharddistributor/canary/processor/shardprocessor.go
index 0fe75769c50..3f399c6d47b 100644
--- a/service/sharddistributor/canary/processor/shardprocessor.go
+++ b/service/sharddistributor/canary/processor/shardprocessor.go
@@ -49,10 +49,11 @@ func (p *ShardProcessor) GetShardReport() executorclient.ShardReport {
 }
 
 // Start implements executorclient.ShardProcessor.
-func (p *ShardProcessor) Start(ctx context.Context) {
+func (p *ShardProcessor) Start(ctx context.Context) error {
 	p.logger.Info("Starting shard processor", zap.String("shardID", p.shardID))
 	p.goRoutineWg.Add(1)
 	go p.process(ctx)
+	return nil
 }
 
 // Stop implements executorclient.ShardProcessor.
diff --git a/service/sharddistributor/canary/processorephemeral/shardprocessor.go b/service/sharddistributor/canary/processorephemeral/shardprocessor.go
index 9e09bbe271c..6cc5821c3b5 100644
--- a/service/sharddistributor/canary/processorephemeral/shardprocessor.go
+++ b/service/sharddistributor/canary/processorephemeral/shardprocessor.go
@@ -59,10 +59,11 @@ func (p *ShardProcessor) GetShardReport() executorclient.ShardReport {
 }
 
 // Start implements executorclient.ShardProcessor.
-func (p *ShardProcessor) Start(ctx context.Context) {
+func (p *ShardProcessor) Start(ctx context.Context) error {
 	p.logger.Info("Starting shard processor", zap.String("shardID", p.shardID))
 	p.goRoutineWg.Add(1)
 	go p.process(ctx)
+	return nil
 }
 
 // Stop implements executorclient.ShardProcessor.
diff --git a/service/sharddistributor/client/executorclient/client.go b/service/sharddistributor/client/executorclient/client.go
index 593d9fdf73c..1411007b2c5 100644
--- a/service/sharddistributor/client/executorclient/client.go
+++ b/service/sharddistributor/client/executorclient/client.go
@@ -30,7 +30,7 @@ type ShardReport struct {
 }
 
 type ShardProcessor interface {
-	Start(ctx context.Context)
+	Start(ctx context.Context) error
 	Stop()
 	GetShardReport() ShardReport
 }
diff --git a/service/sharddistributor/client/executorclient/interface_mock.go b/service/sharddistributor/client/executorclient/interface_mock.go
index 8b5e236b7d9..67195880fe1 100644
--- a/service/sharddistributor/client/executorclient/interface_mock.go
+++ b/service/sharddistributor/client/executorclient/interface_mock.go
@@ -57,9 +57,11 @@ func (mr *MockShardProcessorMockRecorder) GetShardReport() *gomock.Call {
 }
 
 // Start mocks base method.
-func (m *MockShardProcessor) Start(ctx context.Context) {
+func (m *MockShardProcessor) Start(ctx context.Context) error {
 	m.ctrl.T.Helper()
-	m.ctrl.Call(m, "Start", ctx)
+	ret := m.ctrl.Call(m, "Start", ctx)
+	ret0, _ := ret[0].(error)
+	return ret0
 }
 
 // Start indicates an expected call of Start.
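Note: with this change the executorclient.ShardProcessor contract lets Start report failures instead of only logging them. A minimal sketch of a conforming processor, modeled loosely on the canary processors above; the type, its fields, and the validation are illustrative assumptions, not part of this diff:

package exampleprocessor

import (
	"context"
	"errors"
	"sync"

	"github.com/uber/cadence/service/sharddistributor/client/executorclient"
)

// exampleProcessor is a hypothetical implementation of the updated interface.
type exampleProcessor struct {
	shardID string
	cancel  context.CancelFunc
	wg      sync.WaitGroup
}

var _ executorclient.ShardProcessor = (*exampleProcessor)(nil)

// Start can now surface setup failures to the executor instead of only logging them.
func (p *exampleProcessor) Start(ctx context.Context) error {
	if p.shardID == "" {
		return errors.New("shard ID must be set before Start")
	}
	ctx, p.cancel = context.WithCancel(ctx)
	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		<-ctx.Done() // stand-in for the real processing loop
	}()
	return nil
}

// Stop cancels the processing loop and waits for it to exit.
func (p *exampleProcessor) Stop() {
	if p.cancel != nil {
		p.cancel()
	}
	p.wg.Wait()
}

// GetShardReport returns a zero-value report; a real processor would populate it.
func (p *exampleProcessor) GetShardReport() executorclient.ShardReport {
	return executorclient.ShardReport{}
}

The same shape is what lets the matching task-list manager double as a shard processor here: its Start(ctx) error result can be checked by whoever drives it, whether that is the engine or, eventually, the shard distributor's executor.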