diff --git a/common/metrics/defs.go b/common/metrics/defs.go
index 557f3bc571d..2b1b91807f1 100644
--- a/common/metrics/defs.go
+++ b/common/metrics/defs.go
@@ -3538,7 +3538,7 @@ var MetricDefs = map[ServiceIdx]map[MetricIdx]metricDefinition{
 		ReplicationTaskLatency:                 {metricName: "replication_task_latency", metricType: Timer},
 		ExponentialReplicationTaskLatency:      {metricName: "replication_task_latency_ns", metricType: Histogram, exponentialBuckets: Mid1ms24h},
 		ExponentialReplicationTaskFetchLatency: {metricName: "replication_task_fetch_latency_ns", metricType: Histogram, exponentialBuckets: Mid1ms24h},
-		ReplicationTasksFetchedSize:            {metricName: "replication_tasks_fetched_size", metricType: Gauge},
+		ReplicationTasksFetchedSize:            {metricName: "replication_tasks_fetched_size", metricType: Histogram, buckets: ResponseRowSizeBuckets},
 		MutableStateChecksumMismatch:           {metricName: "mutable_state_checksum_mismatch", metricType: Counter},
 		MutableStateChecksumInvalidated:        {metricName: "mutable_state_checksum_invalidated", metricType: Counter},
 		FailoverMarkerCount:                    {metricName: "failover_marker_count", metricType: Counter},
diff --git a/service/history/engine/testdata/engine_for_tests.go b/service/history/engine/testdata/engine_for_tests.go
index 711a9b55e7c..4be52fcbf1c 100644
--- a/service/history/engine/testdata/engine_for_tests.go
+++ b/service/history/engine/testdata/engine_for_tests.go
@@ -114,7 +114,7 @@ func NewEngineForTest(t *testing.T, newEngineFn NewEngineFn) *EngineForTest {
 	// TODO: this should probably return another cluster name, not current
 	replicationTaskFetcher.EXPECT().GetSourceCluster().Return(constants.TestClusterMetadata.GetCurrentClusterName()).AnyTimes()
 	replicationTaskFetcher.EXPECT().GetRateLimiter().Return(quotas.NewDynamicRateLimiter(func() float64 { return 100 })).AnyTimes()
-	replicationTaskFetcher.EXPECT().GetRequestChan().Return(nil).AnyTimes()
+	replicationTaskFetcher.EXPECT().GetRequestChan(gomock.Any()).Return(nil).AnyTimes()
 	replicatonTaskFetchers.EXPECT().GetFetchers().Return([]replication.TaskFetcher{replicationTaskFetcher}).AnyTimes()

 	failoverCoordinator := failover.NewMockCoordinator(controller)
diff --git a/service/history/replication/task_fetcher.go b/service/history/replication/task_fetcher.go
index 70dacce1530..357d64e64c1 100644
--- a/service/history/replication/task_fetcher.go
+++ b/service/history/replication/task_fetcher.go
@@ -54,7 +54,7 @@ type (
 		common.Daemon

 		GetSourceCluster() string
-		GetRequestChan() chan<- *request
+		GetRequestChan(shardID int) chan<- *request
 		GetRateLimiter() quotas.Limiter
 	}

@@ -76,7 +76,7 @@ type (
 		remotePeer    admin.Client
 		rateLimiter   quotas.Limiter
 		timeSource    clock.TimeSource
-		requestChan   chan *request
+		requestChan   []chan *request
 		ctx           context.Context
 		cancelCtx     context.CancelFunc
 		wg            sync.WaitGroup
@@ -167,6 +167,12 @@ func newReplicationTaskFetcher(
 	metricsClient metrics.Client,
 ) TaskFetcher {
 	ctx, cancel := context.WithCancel(context.Background())
+
+	requestChan := make([]chan *request, config.ReplicationTaskFetcherParallelism())
+	for i := 0; i < config.ReplicationTaskFetcherParallelism(); i++ {
+		requestChan[i] = make(chan *request, requestChanBufferSize)
+	}
+
 	fetcher := &taskFetcherImpl{
 		status:        common.DaemonStatusInitialized,
 		config:        config,
@@ -177,7 +183,7 @@ func newReplicationTaskFetcher(
 		sourceCluster: sourceCluster,
 		rateLimiter:   quotas.NewDynamicRateLimiter(config.ReplicationTaskProcessorHostQPS.AsFloat64()),
 		timeSource:    clock.NewRealTimeSource(),
-		requestChan:   make(chan *request, requestChanBufferSize),
+		requestChan:   requestChan,
 		ctx:           ctx,
 		cancelCtx:     cancel,
 	}
@@ -191,11 +197,12 @@ func (f *taskFetcherImpl) Start() {
 		return
 	}

-	// NOTE: we have never run production service with ReplicationTaskFetcherParallelism larger than 1,
-	// the behavior is undefined if we do so. We should consider making this config a boolean.
+	// NOTE: ReplicationTaskFetcherParallelism > 1 is now supported. Each fetcher goroutine handles a subset of shards
+	// (distributed via shardID % parallelism) and runs its own fetch cycle independently.
 	for i := 0; i < f.config.ReplicationTaskFetcherParallelism(); i++ {
+		i := i
 		f.wg.Add(1)
-		go f.fetchTasks()
+		go f.fetchTasks(i)
 	}
 	f.logger.Info("Replication task fetcher started.", tag.Counter(f.config.ReplicationTaskFetcherParallelism()))
 }
@@ -215,13 +222,8 @@ func (f *taskFetcherImpl) Stop() {
 }

 // fetchTasks collects getReplicationTasks request from shards and send out aggregated request to source frontend.
-func (f *taskFetcherImpl) fetchTasks() {
-	startTime := f.timeSource.Now()
+func (f *taskFetcherImpl) fetchTasks(chanIdx int) {
 	defer f.wg.Done()
-	defer func() {
-		totalLatency := f.timeSource.Now().Sub(startTime)
-		f.metricsScope.ExponentialHistogram(metrics.ExponentialReplicationTaskFetchLatency, totalLatency)
-	}()

 	timer := f.timeSource.NewTimer(backoff.JitDuration(
 		f.config.ReplicationTaskFetcherAggregationInterval(),
@@ -232,7 +234,7 @@
 	requestByShard := make(map[int32]*request)
 	for {
 		select {
-		case request := <-f.requestChan:
+		case request := <-f.requestChan[chanIdx]:
 			// Here we only add the request to map. We will wait until timer fires to send the request to remote.
 			if req, ok := requestByShard[request.token.GetShardID()]; ok && req != request {
 				// since this replication task fetcher is per host and replication task processor is per shard
@@ -268,6 +270,12 @@
 }

 func (f *taskFetcherImpl) fetchAndDistributeTasks(requestByShard map[int32]*request) error {
+	startTime := f.timeSource.Now()
+	defer func() {
+		fetchLatency := f.timeSource.Now().Sub(startTime)
+		f.metricsScope.ExponentialHistogram(metrics.ExponentialReplicationTaskFetchLatency, fetchLatency)
+	}()
+
 	if len(requestByShard) == 0 {
 		// We don't receive tasks from previous fetch so processors are all sleeping.
f.logger.Debug("Skip fetching as no processor is asking for tasks.") @@ -289,7 +297,7 @@ func (f *taskFetcherImpl) fetchAndDistributeTasks(requestByShard map[int32]*requ for _, messages := range messagesByShard { totalTasks += len(messages.ReplicationTasks) } - f.metricsScope.UpdateGauge(metrics.ReplicationTasksFetchedSize, float64(totalTasks)) + f.metricsScope.RecordHistogramValue(metrics.ReplicationTasksFetchedSize, float64(totalTasks)) f.logger.Debug("Successfully fetched replication tasks.", tag.Counter(len(messagesByShard))) for shardID, tasks := range messagesByShard { @@ -329,8 +337,10 @@ func (f *taskFetcherImpl) GetSourceCluster() string { } // GetRequestChan returns the request chan for the fetcher -func (f *taskFetcherImpl) GetRequestChan() chan<- *request { - return f.requestChan +func (f *taskFetcherImpl) GetRequestChan(shardID int) chan<- *request { + chanIdx := shardID % f.config.ReplicationTaskFetcherParallelism() + + return f.requestChan[chanIdx] } // GetRateLimiter returns the host level rate limiter for the fetcher diff --git a/service/history/replication/task_fetcher_mock.go b/service/history/replication/task_fetcher_mock.go index 96d007d4e6b..ed6f0f5cbea 100644 --- a/service/history/replication/task_fetcher_mock.go +++ b/service/history/replication/task_fetcher_mock.go @@ -56,17 +56,17 @@ func (mr *MockTaskFetcherMockRecorder) GetRateLimiter() *gomock.Call { } // GetRequestChan mocks base method. -func (m *MockTaskFetcher) GetRequestChan() chan<- *request { +func (m *MockTaskFetcher) GetRequestChan(shardID int) chan<- *request { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRequestChan") + ret := m.ctrl.Call(m, "GetRequestChan", shardID) ret0, _ := ret[0].(chan<- *request) return ret0 } // GetRequestChan indicates an expected call of GetRequestChan. -func (mr *MockTaskFetcherMockRecorder) GetRequestChan() *gomock.Call { +func (mr *MockTaskFetcherMockRecorder) GetRequestChan(shardID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestChan", reflect.TypeOf((*MockTaskFetcher)(nil).GetRequestChan)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestChan", reflect.TypeOf((*MockTaskFetcher)(nil).GetRequestChan), shardID) } // GetSourceCluster mocks base method. 
diff --git a/service/history/replication/task_fetcher_test.go b/service/history/replication/task_fetcher_test.go
index d23b1eb511b..4e1610b70fa 100644
--- a/service/history/replication/task_fetcher_test.go
+++ b/service/history/replication/task_fetcher_test.go
@@ -218,7 +218,7 @@ func (s *taskFetcherSuite) TestLifecycle() {
 	s.taskFetcher.Start()
 	defer s.taskFetcher.Stop()

-	requestChan := s.taskFetcher.GetRequestChan()
+	requestChan := s.taskFetcher.GetRequestChan(0)
 	// send 3 replication requests to the fetcher
 	requestChan <- req0
 	requestChan <- req1
@@ -278,3 +278,55 @@ func TestTaskFetchers(t *testing.T) {
 	fetchers.Start()
 	fetchers.Stop()
 }
+
+func TestTaskFetcherParallelism(t *testing.T) {
+	defer goleak.VerifyNone(t)
+	logger := testlogger.New(t)
+	cfg := config.NewForTest()
+	parallelism := 4
+	cfg.ReplicationTaskFetcherParallelism = dynamicproperties.GetIntPropertyFn(parallelism)
+
+	ctrl := gomock.NewController(t)
+	mockAdminClient := admin.NewMockClient(ctrl)
+
+	fetcher := newReplicationTaskFetcher(
+		logger,
+		"standby",
+		"active",
+		cfg,
+		mockAdminClient,
+		metrics.NewNoopMetricsClient(),
+	).(*taskFetcherImpl)
+
+	// Test 1: Verify correct number of channels created
+	assert.Equal(t, parallelism, len(fetcher.requestChan), "Should create 4 request channels")
+
+	// Test 2: Verify shard-to-channel mapping
+	chan0 := fetcher.GetRequestChan(0)
+	chan1 := fetcher.GetRequestChan(1)
+	chan4 := fetcher.GetRequestChan(4) // 4 % 4 = 0, should be same as chan0
+	chan5 := fetcher.GetRequestChan(5) // 5 % 4 = 1, should be same as chan1
+
+	assert.Equal(t, chan0, chan4, "Shards 0 and 4 should map to same channel (0 % 4 == 4 % 4)")
+	assert.Equal(t, chan1, chan5, "Shards 1 and 5 should map to same channel (1 % 4 == 5 % 4)")
+	assert.NotEqual(t, chan0, chan1, "Shards with different shardID % parallelism should map to different channels")
+
+	// Test 3: Start fetcher and verify WaitGroup is properly incremented
+	fetcher.Start()
+
+	// The WaitGroup counter should now be 4 (one per goroutine).
+	// We can verify this by calling Stop(), which waits on the WaitGroup;
+	// if it hangs or times out, the goroutines weren't started correctly.
+	done := make(chan bool)
+	go func() {
+		fetcher.Stop()
+		done <- true
+	}()
+
+	select {
+	case <-done:
+		// Success - all goroutines exited cleanly
+	case <-time.After(11 * time.Second):
+		t.Fatal("Stop() timed out - goroutines may not have been started correctly")
+	}
+}
diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go
index 4d2e7db086a..d5d4c3de55f 100644
--- a/service/history/replication/task_processor.go
+++ b/service/history/replication/task_processor.go
@@ -149,7 +149,7 @@ func NewTaskProcessor(
 		taskRetryPolicy: taskRetryPolicy,
 		dlqRetryPolicy:  dlqRetryPolicy,
 		noTaskRetrier:   noTaskRetrier,
-		requestChan:     taskFetcher.GetRequestChan(),
+		requestChan:     taskFetcher.GetRequestChan(shardID),
 		syncShardChan:   make(chan *types.SyncShardStatus, 1),
 		done:            make(chan struct{}),
 		lastProcessedMessageID: constants.EmptyMessageID,
diff --git a/service/history/replication/task_processor_test.go b/service/history/replication/task_processor_test.go
index d2e9bda5e4a..2dc9b5003e6 100644
--- a/service/history/replication/task_processor_test.go
+++ b/service/history/replication/task_processor_test.go
@@ -93,7 +93,7 @@ func (f fakeTaskFetcher) Stop() {}
 func (f fakeTaskFetcher) GetSourceCluster() string {
 	return f.sourceCluster
 }
-func (f fakeTaskFetcher) GetRequestChan() chan<- *request {
+func (f fakeTaskFetcher) GetRequestChan(shardID int) chan<- *request {
 	return f.requestChan
 }
 func (f fakeTaskFetcher) GetRateLimiter() quotas.Limiter {
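
For reference, a minimal standalone sketch of the routing rule this change introduces: each shard's fetch request is steered to one of the per-goroutine channels via shardID % ReplicationTaskFetcherParallelism, so every fetch loop drains a disjoint subset of shards. This is an illustrative toy program, not code from the repository; the parallelism value and shard count below are assumed for the example.

package main

import "fmt"

func main() {
	const parallelism = 4 // assumed ReplicationTaskFetcherParallelism value
	const numShards = 16  // assumed shard count, for illustration only

	// counts[i] tracks how many shards would enqueue into requestChan[i].
	counts := make([]int, parallelism)
	for shardID := 0; shardID < numShards; shardID++ {
		chanIdx := shardID % parallelism // same routing rule as GetRequestChan(shardID)
		counts[chanIdx]++
	}
	fmt.Println(counts) // [4 4 4 4]: shards spread evenly across the fetch goroutines
}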