diff --git a/common/config/config.go b/common/config/config.go index 904a2f7eb8e..a9e510d8ad0 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -658,8 +658,9 @@ type ( } LeaderProcess struct { - Period time.Duration `yaml:"period"` - HeartbeatTTL time.Duration `yaml:"heartbeatTTL"` + Period time.Duration `yaml:"period"` + HeartbeatTTL time.Duration `yaml:"heartbeatTTL"` + ShardStatsTTL time.Duration `yaml:"shardStatsTTL"` } ) diff --git a/config/development.yaml b/config/development.yaml index d485c7ac4ba..1f34cb3d2a0 100644 --- a/config/development.yaml +++ b/config/development.yaml @@ -186,3 +186,4 @@ shardDistribution: process: period: 1s heartbeatTTL: 2s + shardStatsTTL: 60s diff --git a/service/sharddistributor/config/config.go b/service/sharddistributor/config/config.go index 79106ffc635..4ae80e489fc 100644 --- a/service/sharddistributor/config/config.go +++ b/service/sharddistributor/config/config.go @@ -79,8 +79,9 @@ type ( } LeaderProcess struct { - Period time.Duration `yaml:"period"` - HeartbeatTTL time.Duration `yaml:"heartbeatTTL"` + Period time.Duration `yaml:"period"` + HeartbeatTTL time.Duration `yaml:"heartbeatTTL"` + ShardStatsTTL time.Duration `yaml:"shardStatsTTL"` } ) @@ -97,6 +98,10 @@ const ( MigrationModeONBOARDED = "onboarded" ) +const ( + DefaultShardStatsTTL = time.Minute +) + // ConfigMode maps string migration mode values to types.MigrationMode var ConfigMode = map[string]types.MigrationMode{ MigrationModeINVALID: types.MigrationModeINVALID, diff --git a/service/sharddistributor/leader/process/processor.go b/service/sharddistributor/leader/process/processor.go index 901e6822ef7..9179357f07c 100644 --- a/service/sharddistributor/leader/process/processor.go +++ b/service/sharddistributor/leader/process/processor.go @@ -76,12 +76,15 @@ func NewProcessorFactory( timeSource clock.TimeSource, cfg config.ShardDistribution, ) Factory { - if cfg.Process.Period == 0 { + if cfg.Process.Period <= 0 { cfg.Process.Period = _defaultPeriod } - if cfg.Process.HeartbeatTTL == 0 { + if cfg.Process.HeartbeatTTL <= 0 { cfg.Process.HeartbeatTTL = _defaultHearbeatTTL } + if cfg.Process.ShardStatsTTL <= 0 { + cfg.Process.ShardStatsTTL = config.DefaultShardStatsTTL + } return &processorFactory{ logger: logger, @@ -237,7 +240,7 @@ func (p *namespaceProcessor) runShardStatsCleanupLoop(ctx context.Context) { continue } staleShardStats := p.identifyStaleShardStats(namespaceState) - if len(staleShardStats) > 0 { + if len(staleShardStats) == 0 { // No stale shard stats to delete continue } @@ -267,7 +270,7 @@ func (p *namespaceProcessor) identifyStaleExecutors(namespaceState *store.Namesp func (p *namespaceProcessor) identifyStaleShardStats(namespaceState *store.NamespaceState) []string { activeShards := make(map[string]struct{}) now := p.timeSource.Now().Unix() - shardStatsTTL := int64(p.cfg.HeartbeatTTL.Seconds()) + shardStatsTTL := int64(p.cfg.ShardStatsTTL.Seconds()) // 1. build set of active executors diff --git a/service/sharddistributor/leader/process/processor_test.go b/service/sharddistributor/leader/process/processor_test.go index 4027e9a496b..206c0ebf3c3 100644 --- a/service/sharddistributor/leader/process/processor_test.go +++ b/service/sharddistributor/leader/process/processor_test.go @@ -44,8 +44,9 @@ func setupProcessorTest(t *testing.T, namespaceType string) *testDependencies { mockedClock, config.ShardDistribution{ Process: config.LeaderProcess{ - Period: time.Second, - HeartbeatTTL: time.Second, + Period: time.Second, + HeartbeatTTL: time.Second, + ShardStatsTTL: 10 * time.Second, }, }, ), @@ -259,7 +260,7 @@ func TestCleanupStaleShardStats(t *testing.T) { shardStats := map[string]store.ShardStatistics{ "shard-1": {SmoothedLoad: 1.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()}, "shard-2": {SmoothedLoad: 2.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()}, - "shard-3": {SmoothedLoad: 3.0, LastUpdateTime: now.Add(-2 * time.Second).Unix(), LastMoveTime: now.Add(-2 * time.Second).Unix()}, + "shard-3": {SmoothedLoad: 3.0, LastUpdateTime: now.Add(-11 * time.Second).Unix(), LastMoveTime: now.Add(-11 * time.Second).Unix()}, } namespaceState := &store.NamespaceState{ diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore.go b/service/sharddistributor/store/etcd/executorstore/etcdstore.go index 9bd2b465d61..2bd88baae4e 100644 --- a/service/sharddistributor/store/etcd/executorstore/etcdstore.go +++ b/service/sharddistributor/store/etcd/executorstore/etcdstore.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "strconv" "time" @@ -30,13 +31,19 @@ var ( ) type executorStoreImpl struct { - client *clientv3.Client - prefix string - logger log.Logger - shardCache *shardcache.ShardToExecutorCache - timeSource clock.TimeSource + client *clientv3.Client + prefix string + logger log.Logger + shardCache *shardcache.ShardToExecutorCache + timeSource clock.TimeSource + maxStatsPersistIntervalSeconds int64 // Max interval (seconds) before we force a shard-stat persist. } +// Constants for gating shard statistics writes to reduce etcd load. +const ( + shardStatsEpsilon = 0.05 +) + // shardStatisticsUpdate holds the staged statistics for a shard so we can write them // to etcd after the main AssignShards transaction commits. type shardStatisticsUpdate struct { @@ -88,12 +95,18 @@ func NewStore(p ExecutorStoreParams) (store.Store, error) { timeSource = clock.NewRealTimeSource() } + shardStatsTTL := p.Cfg.Process.ShardStatsTTL + if shardStatsTTL <= 0 { + shardStatsTTL = config.DefaultShardStatsTTL + } + store := &executorStoreImpl{ - client: etcdClient, - prefix: etcdCfg.Prefix, - logger: p.Logger, - shardCache: shardCache, - timeSource: timeSource, + client: etcdClient, + prefix: etcdCfg.Prefix, + logger: p.Logger, + shardCache: shardCache, + timeSource: timeSource, + maxStatsPersistIntervalSeconds: deriveStatsPersistInterval(shardStatsTTL), } p.Lifecycle.Append(fx.StartStopHook(store.Start, store.Stop)) @@ -153,9 +166,134 @@ func (s *executorStoreImpl) RecordHeartbeat(ctx context.Context, namespace, exec if err != nil { return fmt.Errorf("record heartbeat: %w", err) } + + s.recordShardStatistics(ctx, namespace, executorID, request.ReportedShards) + return nil } +func deriveStatsPersistInterval(shardStatsTTL time.Duration) int64 { + ttlSeconds := int64(shardStatsTTL.Seconds()) + return max(1, ttlSeconds-1) +} + +func (s *executorStoreImpl) recordShardStatistics(ctx context.Context, namespace, executorID string, reported map[string]*types.ShardStatusReport) { + if len(reported) == 0 { + return + } + + now := s.timeSource.Now().Unix() + + for shardID, report := range reported { + if report == nil { + s.logger.Warn("empty report; skipping EWMA update", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + ) + continue + } + + load := report.ShardLoad + if math.IsNaN(load) || math.IsInf(load, 0) { + s.logger.Warn( + "invalid shard load reported; skipping EWMA update", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + ) + continue + } + + shardStatsKey, err := etcdkeys.BuildShardKey(s.prefix, namespace, shardID, etcdkeys.ShardStatisticsKey) + if err != nil { + s.logger.Warn( + "failed to build shard statistics key from heartbeat", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + tag.Error(err), + ) + continue + } + + statsResp, err := s.client.Get(ctx, shardStatsKey) + if err != nil { + s.logger.Warn( + "failed to read shard statistics for heartbeat update", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + tag.Error(err), + ) + continue + } + + var stats store.ShardStatistics + if len(statsResp.Kvs) > 0 { + err := common.DecompressAndUnmarshal(statsResp.Kvs[0].Value, &stats) + if err != nil { + s.logger.Warn( + "failed to unmarshal shard statistics for heartbeat update", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + tag.Error(err), + ) + continue + } + } + + // Update smoothed load via EWMA. + prevSmoothed := stats.SmoothedLoad + prevUpdate := stats.LastUpdateTime + newSmoothed := ewmaSmoothedLoad(prevSmoothed, load, prevUpdate, now) + + // Decide whether to persist this update. We always persist if this is the + // first observation (prevUpdate == 0). Otherwise, if the change is small + // and the previous persist is recent, skip the write to reduce etcd load. + shouldPersist := true + if prevUpdate > 0 { + age := now - prevUpdate + delta := math.Abs(newSmoothed - prevSmoothed) + if delta < shardStatsEpsilon && age < s.maxStatsPersistIntervalSeconds { + shouldPersist = false + } + } + + if !shouldPersist { + // Skip persisting, proceed to next shard. + continue + } + + stats.SmoothedLoad = newSmoothed + stats.LastUpdateTime = now + + payload, err := json.Marshal(stats) + if err != nil { + s.logger.Warn( + "failed to marshal shard statistics after heartbeat", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + tag.Error(err), + ) + continue + } + + _, err = s.client.Put(ctx, shardStatsKey, string(payload)) + if err != nil { + s.logger.Warn( + "failed to persist shard statistics from heartbeat", + tag.ShardNamespace(namespace), + tag.ShardExecutor(executorID), + tag.ShardKey(shardID), + tag.Error(err), + ) + } + } +} + // GetHeartbeat retrieves the last known heartbeat state for a single executor. func (s *executorStoreImpl) GetHeartbeat(ctx context.Context, namespace string, executorID string) (*store.HeartbeatState, *store.AssignedState, error) { // The prefix for all keys related to a single executor. @@ -741,3 +879,13 @@ func (s *executorStoreImpl) applyShardStatisticsUpdates(ctx context.Context, nam } } } + +func ewmaSmoothedLoad(prev, current float64, lastUpdate, now int64) float64 { + const tauSeconds = 30.0 // smaller = more responsive, larger = smoother + if lastUpdate <= 0 || tauSeconds <= 0 { + return current + } + dt := max(now-lastUpdate, 0) + alpha := 1 - math.Exp(-float64(dt)/tauSeconds) + return (1-alpha)*prev + alpha*current +} diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go index a47565768bd..9d6d5c7f7e3 100644 --- a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go +++ b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/fx/fxtest" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/types" "github.com/uber/cadence/service/sharddistributor/store" @@ -91,6 +92,154 @@ func TestRecordHeartbeat(t *testing.T) { assert.Equal(t, "value-2", string(resp.Kvs[0].Value)) } +func TestRecordHeartbeatUpdatesShardStatistics(t *testing.T) { + tc := testhelper.SetupStoreTestCluster(t) + executorStore := createStore(t, tc) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + executorID := "executor-shard-stats" + shardID := "shard-with-load" + + initialStats := store.ShardStatistics{ + SmoothedLoad: 1.23, + LastUpdateTime: 10, + LastMoveTime: 123, + } + + shardStatsKey, err := etcdkeys.BuildShardKey(tc.EtcdPrefix, tc.Namespace, shardID, etcdkeys.ShardStatisticsKey) + require.NoError(t, err) + payload, err := json.Marshal(initialStats) + require.NoError(t, err) + _, err = tc.Client.Put(ctx, shardStatsKey, string(payload)) + require.NoError(t, err) + + nowTS := time.Now().Unix() + + req := store.HeartbeatState{ + LastHeartbeat: nowTS, + Status: types.ExecutorStatusACTIVE, + ReportedShards: map[string]*types.ShardStatusReport{ + shardID: { + Status: types.ShardStatusREADY, + ShardLoad: 45.6, + }, + }, + } + + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, req)) + + nsState, err := executorStore.GetState(ctx, tc.Namespace) + require.NoError(t, err) + + require.Contains(t, nsState.ShardStats, shardID) + updated := nsState.ShardStats[shardID] + + assert.InDelta(t, 45.6, updated.SmoothedLoad, 1e-9) + assert.GreaterOrEqual(t, updated.LastUpdateTime, nowTS) + assert.Equal(t, initialStats.LastMoveTime, updated.LastMoveTime) +} + +func TestRecordHeartbeatSkipsShardStatisticsWithNilReport(t *testing.T) { + tc := testhelper.SetupStoreTestCluster(t) + executorStore := createStore(t, tc) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + executorID := "executor-missing-load" + validShardID := "shard-with-valid-load" + skippedShardID := "shard-missing-load" + + nowTS := time.Now().Unix() + + req := store.HeartbeatState{ + LastHeartbeat: nowTS, + Status: types.ExecutorStatusACTIVE, + ReportedShards: map[string]*types.ShardStatusReport{ + validShardID: { + Status: types.ShardStatusREADY, + ShardLoad: 3.21, + }, + skippedShardID: nil, + }, + } + + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, req)) + + nsState, err := executorStore.GetState(ctx, tc.Namespace) + require.NoError(t, err) + + require.Contains(t, nsState.ShardStats, validShardID) + validStats := nsState.ShardStats[validShardID] + assert.InDelta(t, 3.21, validStats.SmoothedLoad, 1e-9) + assert.Greater(t, validStats.LastUpdateTime, int64(0)) + + assert.NotContains(t, nsState.ShardStats, skippedShardID) +} + +func TestRecordHeartbeatShardStatisticsThrottlesWrites(t *testing.T) { + tc := testhelper.SetupStoreTestCluster(t) + tc.LeaderCfg.Process.HeartbeatTTL = 10 * time.Second + tc.LeaderCfg.Process.ShardStatsTTL = 10 * time.Second + mockTS := clock.NewMockedTimeSourceAt(time.Unix(1000, 0)) + executorStore := createStoreWithTimeSource(t, tc, mockTS) + esImpl, ok := executorStore.(*executorStoreImpl) + require.True(t, ok, "unexpected store implementation") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + executorID := "executor-shard-stats-throttle" + shardID := "shard-stats-throttle" + + baseLoad := 0.40 + smallDelta := shardStatsEpsilon / 2 + intervalSeconds := esImpl.maxStatsPersistIntervalSeconds + halfIntervalSeconds := intervalSeconds / 2 + if halfIntervalSeconds == 0 { + halfIntervalSeconds = 1 + } + + // First heartbeat should always persist stats. + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{ + LastHeartbeat: mockTS.Now().Unix(), + Status: types.ExecutorStatusACTIVE, + ReportedShards: map[string]*types.ShardStatusReport{ + shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad}, + }, + })) + statsAfterFirst := getShardStats(ctx, t, executorStore, tc.Namespace, shardID) + require.NotNil(t, statsAfterFirst) + + // Advance time by less than the persist interval and provide a small delta: should skip the write. + mockTS.Advance(time.Duration(halfIntervalSeconds) * time.Second) + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{ + LastHeartbeat: mockTS.Now().Unix(), + Status: types.ExecutorStatusACTIVE, + ReportedShards: map[string]*types.ShardStatusReport{ + shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad + smallDelta}, + }, + })) + statsAfterSkip := getShardStats(ctx, t, executorStore, tc.Namespace, shardID) + require.NotNil(t, statsAfterSkip) + assert.Equal(t, statsAfterFirst.LastUpdateTime, statsAfterSkip.LastUpdateTime, "small recent deltas should not trigger a persist") + + // Advance time beyond the max persist interval, even small deltas should now persist. + mockTS.Advance(time.Duration(intervalSeconds) * time.Second) + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{ + LastHeartbeat: mockTS.Now().Unix(), + Status: types.ExecutorStatusACTIVE, + ReportedShards: map[string]*types.ShardStatusReport{ + shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad + smallDelta/2}, + }, + })) + statsAfterForce := getShardStats(ctx, t, executorStore, tc.Namespace, shardID) + require.NotNil(t, statsAfterForce) + assert.Greater(t, statsAfterForce.LastUpdateTime, statsAfterSkip.LastUpdateTime, "stale stats must be refreshed even if delta is small") +} + func TestGetHeartbeat(t *testing.T) { tc := testhelper.SetupStoreTestCluster(t) executorStore := createStore(t, tc) @@ -608,3 +757,27 @@ func createStore(t *testing.T, tc *testhelper.StoreTestCluster) store.Store { require.NoError(t, err) return store } + +func createStoreWithTimeSource(t *testing.T, tc *testhelper.StoreTestCluster, ts clock.TimeSource) store.Store { + t.Helper() + store, err := NewStore(ExecutorStoreParams{ + Client: tc.Client, + Cfg: tc.LeaderCfg, + Lifecycle: fxtest.NewLifecycle(t), + Logger: testlogger.New(t), + TimeSource: ts, + }) + require.NoError(t, err) + return store +} + +func getShardStats(ctx context.Context, t *testing.T, s store.Store, namespace, shardID string) *store.ShardStatistics { + t.Helper() + nsState, err := s.GetState(ctx, namespace) + require.NoError(t, err) + stats, ok := nsState.ShardStats[shardID] + if !ok { + return nil + } + return &stats +}