diff --git a/service/sharddistributor/leader/process/processor.go b/service/sharddistributor/leader/process/processor.go
index 76abb573900..a6572cb5a01 100644
--- a/service/sharddistributor/leader/process/processor.go
+++ b/service/sharddistributor/leader/process/processor.go
@@ -2,6 +2,7 @@ package process
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"maps"
 	"math/rand"
@@ -214,6 +215,11 @@ func (p *namespaceProcessor) runRebalancingLoop(ctx context.Context) {
 			err = p.rebalanceShards(ctx)
 		}
 		if err != nil {
+			// Exit quietly only when this loop's own context is done. A deadline error by
+			// itself may come from one timed-out operation and must not end the loop.
+			if ctx.Err() != nil && isCancelledOrDeadlineExceeded(err) {
+				return
+			}
 			p.logger.Error("rebalance failed", tag.Error(err))
 		}
 	}
@@ -233,6 +239,11 @@ func (p *namespaceProcessor) runShardStatsCleanupLoop(ctx context.Context) {
 			p.logger.Info("Periodic shard stats cleanup triggered.")
 			namespaceState, err := p.shardStore.GetState(ctx, p.namespaceCfg.Name)
 			if err != nil {
+				// Exit quietly only when this loop's own context is done. A deadline error by
+				// itself may come from one timed-out operation and must not end the loop.
+				if ctx.Err() != nil && isCancelledOrDeadlineExceeded(err) {
+					return
+				}
 				p.logger.Error("Failed to get state for shard stats cleanup", tag.Error(err))
 				continue
 			}
@@ -242,6 +253,11 @@ func (p *namespaceProcessor) runShardStatsCleanupLoop(ctx context.Context) {
 				continue
 			}
 			if err := p.shardStore.DeleteShardStats(ctx, p.namespaceCfg.Name, staleShardStats, p.election.Guard()); err != nil {
+				// Exit quietly only when this loop's own context is done. A deadline error by
+				// itself may come from one timed-out operation and must not end the loop.
+				if ctx.Err() != nil && isCancelledOrDeadlineExceeded(err) {
+					return
+				}
 				p.logger.Error("Failed to delete stale shard stats", tag.Error(err))
 			}
 		}
@@ -340,6 +356,9 @@ func (p *namespaceProcessor) rebalanceShardsImpl(ctx context.Context, metricsLoo
 
 	namespaceState, err := p.shardStore.GetState(ctx, p.namespaceCfg.Name)
 	if err != nil {
+		if isCancelledOrDeadlineExceeded(err) {
+			return err
+		}
 		return fmt.Errorf("get state: %w", err)
 	}
 
@@ -386,6 +405,9 @@ func (p *namespaceProcessor) rebalanceShardsImpl(ctx context.Context, metricsLoo
 		ExecutorsToDelete: staleExecutors,
 	}, p.election.Guard())
 	if err != nil {
+		if isCancelledOrDeadlineExceeded(err) {
+			return err
+		}
 		return fmt.Errorf("assign shards: %w", err)
 	}
 
@@ -586,3 +608,8 @@ func makeShards(num int64) []string {
 	}
 	return shards
 }
+
+func isCancelledOrDeadlineExceeded(err error) bool {
+	return errors.Is(err, context.Canceled) ||
+		errors.Is(err, context.DeadlineExceeded)
+}
diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore.go b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
index 4c4a2f7462f..5b5add49415 100644
--- a/service/sharddistributor/store/etcd/executorstore/etcdstore.go
+++ b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
@@ -233,8 +233,18 @@ func (s *executorStoreImpl) GetState(ctx context.Context, namespace string) (*st
 	executorPrefix := etcdkeys.BuildExecutorsPrefix(s.prefix, namespace)
 	resp, err := s.client.Get(ctx, executorPrefix, clientv3.WithPrefix())
 	if err != nil {
+		// Pass cancellation/timeout through unwrapped. Return err itself rather than
+		// ctx.Err(): ctx.Err() is nil when the deadline did not come from ctx, and a
+		// (nil, nil) result would read as success to the caller.
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+			return nil, err
+		}
 		return nil, fmt.Errorf("get executor data: %w", err)
 	}
+	// Do not process a response for a request whose context is already done.
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		return nil, ctxErr
+	}
 
 	for _, kv := range resp.Kvs {
 		key := string(kv.Key)
@@ -276,8 +286,18 @@ func (s *executorStoreImpl) GetState(ctx context.Context, namespace string) (*st
 	shardsPrefix := etcdkeys.BuildShardsPrefix(s.prefix, namespace)
 	shardResp, err := s.client.Get(ctx, shardsPrefix, clientv3.WithPrefix())
 	if err != nil {
+		// Pass cancellation/timeout through unwrapped. Return err itself rather than
+		// ctx.Err(): ctx.Err() is nil when the deadline did not come from ctx, and a
+		// (nil, nil) result would read as success to the caller.
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+			return nil, err
+		}
 		return nil, fmt.Errorf("get shard data: %w", err)
 	}
+	// Do not process a response for a request whose context is already done.
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		return nil, ctxErr
+	}
 	for _, kv := range shardResp.Kvs {
 		shardID, shardKeyType, err := etcdkeys.ParseShardKey(s.prefix, namespace, string(kv.Key))
 		if err != nil {