diff --git a/service/sharddistributor/client/spectatorclient/client.go b/service/sharddistributor/client/spectatorclient/client.go index c4b19eb094a..8f6e2f1e8b3 100644 --- a/service/sharddistributor/client/spectatorclient/client.go +++ b/service/sharddistributor/client/spectatorclient/client.go @@ -20,13 +20,52 @@ import ( //go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interface_mock.go . Spectator +type Spectators struct { + spectators map[string]Spectator +} + +func (s *Spectators) ForNamespace(namespace string) (Spectator, error) { + spectator, ok := s.spectators[namespace] + if !ok { + return nil, fmt.Errorf("spectator not found for namespace %s", namespace) + } + return spectator, nil +} + +func (s *Spectators) Start(ctx context.Context) error { + for namespace, spectator := range s.spectators { + if err := spectator.Start(ctx); err != nil { + return fmt.Errorf("start spectator for namespace %s: %w", namespace, err) + } + } + return nil +} + +func (s *Spectators) Stop() { + for _, spectator := range s.spectators { + spectator.Stop() + } +} + +func NewSpectators(params Params) (*Spectators, error) { + spectators := make(map[string]Spectator) + for _, namespace := range params.Config.Namespaces { + spectator, err := NewSpectatorWithNamespace(params, namespace.Namespace) + if err != nil { + return nil, fmt.Errorf("create spectator for namespace %s: %w", namespace.Namespace, err) + } + + spectators[namespace.Namespace] = spectator + } + return &Spectators{spectators: spectators}, nil +} + type Spectator interface { Start(ctx context.Context) error Stop() - // GetShardOwner returns the owner of a shard. It first checks the local cache, - // and if not found, falls back to querying the shard distributor directly. - GetShardOwner(ctx context.Context, shardKey string) (string, error) + // GetShardOwner returns the owner of a shard + GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) } type Params struct { @@ -109,21 +148,9 @@ func createShardDistributorClient(yarpcClient sharddistributorv1.ShardDistributo // Module creates a spectator module using auto-selection (single namespace only) func Module() fx.Option { return fx.Module("shard-distributor-spectator-client", - fx.Provide(NewSpectator), - fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) { - lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop)) - }), - ) -} - -// ModuleWithNamespace creates a spectator module for a specific namespace -func ModuleWithNamespace(namespace string) fx.Option { - return fx.Module(fmt.Sprintf("shard-distributor-spectator-client-%s", namespace), - fx.Provide(func(params Params) (Spectator, error) { - return NewSpectatorWithNamespace(params, namespace) - }), - fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) { - lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop)) + fx.Provide(NewSpectators), + fx.Invoke(func(spectators *Spectators, lc fx.Lifecycle) { + lc.Append(fx.StartStopHook(spectators.Start, spectators.Stop)) }), ) } diff --git a/service/sharddistributor/client/spectatorclient/clientimpl.go b/service/sharddistributor/client/spectatorclient/clientimpl.go index 88934cdadf6..15c6d32d8b5 100644 --- a/service/sharddistributor/client/spectatorclient/clientimpl.go +++ b/service/sharddistributor/client/spectatorclient/clientimpl.go @@ -103,6 +103,7 @@ func (s *spectatorImpl) watchLoop() { // Server shutdown or network issue - recreate stream (load balancer will route to new server) s.logger.Info("Stream ended, reconnecting", tag.ShardNamespace(s.namespace)) + s.timeSource.Sleep(backoff.JitDuration(streamRetryInterval, streamRetryJitterCoeff)) } } @@ -163,10 +164,10 @@ func (s *spectatorImpl) handleResponse(response *types.WatchNamespaceStateRespon tag.Counter(len(response.Executors))) } -// GetShardOwner returns the executor ID for a given shard. +// GetShardOwner returns the full owner information including metadata for a given shard. // It first waits for the initial state to be received, then checks the cache. // If not found in cache, it falls back to querying the shard distributor directly. -func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (string, error) { +func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) { // Wait for first state to be received to avoid flooding shard distributor on startup s.firstStateWG.Wait() @@ -176,7 +177,7 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str s.stateMu.RUnlock() if owner != nil { - return owner.ExecutorID, nil + return owner, nil } // Cache miss - fall back to RPC call @@ -189,8 +190,11 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str ShardKey: shardKey, }) if err != nil { - return "", fmt.Errorf("get shard owner from shard distributor: %w", err) + return nil, fmt.Errorf("get shard owner from shard distributor: %w", err) } - return response.Owner, nil + return &ShardOwner{ + ExecutorID: response.Owner, + Metadata: response.Metadata, + }, nil } diff --git a/service/sharddistributor/client/spectatorclient/clientimpl_test.go b/service/sharddistributor/client/spectatorclient/clientimpl_test.go index a0c60aa2706..637883b1647 100644 --- a/service/sharddistributor/client/spectatorclient/clientimpl_test.go +++ b/service/sharddistributor/client/spectatorclient/clientimpl_test.go @@ -44,6 +44,9 @@ func TestWatchLoopBasicFlow(t *testing.T) { Executors: []*types.ExecutorShardAssignment{ { ExecutorID: "executor-1", + Metadata: map[string]string{ + "grpc_address": "127.0.0.1:7953", + }, AssignedShards: []*types.Shard{ {ShardKey: "shard-1"}, {ShardKey: "shard-2"}, @@ -72,11 +75,12 @@ func TestWatchLoopBasicFlow(t *testing.T) { // Query shard owner owner, err := spectator.GetShardOwner(context.Background(), "shard-1") assert.NoError(t, err) - assert.Equal(t, "executor-1", owner) + assert.Equal(t, "executor-1", owner.ExecutorID) + assert.Equal(t, "127.0.0.1:7953", owner.Metadata["grpc_address"]) owner, err = spectator.GetShardOwner(context.Background(), "shard-2") assert.NoError(t, err) - assert.Equal(t, "executor-1", owner) + assert.Equal(t, "executor-1", owner.ExecutorID) } func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) { @@ -103,7 +107,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) { // First Recv returns state mockStream.EXPECT().Recv().Return(&types.WatchNamespaceStateResponse{ Executors: []*types.ExecutorShardAssignment{ - {ExecutorID: "executor-1", AssignedShards: []*types.Shard{{ShardKey: "shard-1"}}}, + { + ExecutorID: "executor-1", + Metadata: map[string]string{ + "grpc_address": "127.0.0.1:7953", + }, + AssignedShards: []*types.Shard{{ShardKey: "shard-1"}}, + }, }, }, nil) @@ -122,7 +132,12 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) { Namespace: "test-ns", ShardKey: "unknown-shard", }). - Return(&types.GetShardOwnerResponse{Owner: "executor-2"}, nil) + Return(&types.GetShardOwnerResponse{ + Owner: "executor-2", + Metadata: map[string]string{ + "grpc_address": "127.0.0.1:7954", + }, + }, nil) spectator.Start(context.Background()) defer spectator.Stop() @@ -132,12 +147,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) { // Cache hit owner, err := spectator.GetShardOwner(context.Background(), "shard-1") assert.NoError(t, err) - assert.Equal(t, "executor-1", owner) + assert.Equal(t, "executor-1", owner.ExecutorID) // Cache miss - should trigger RPC owner, err = spectator.GetShardOwner(context.Background(), "unknown-shard") assert.NoError(t, err) - assert.Equal(t, "executor-2", owner) + assert.Equal(t, "executor-2", owner.ExecutorID) + assert.Equal(t, "127.0.0.1:7954", owner.Metadata["grpc_address"]) } func TestStreamReconnection(t *testing.T) { @@ -188,7 +204,9 @@ func TestStreamReconnection(t *testing.T) { spectator.Start(context.Background()) defer spectator.Stop() - // Advance time for retry + // Wait for the goroutine to be blocked in Sleep, then advance time + mockTimeSource.BlockUntil(1) // Wait for 1 goroutine to be blocked in Sleep mockTimeSource.Advance(2 * time.Second) + spectator.firstStateWG.Wait() } diff --git a/service/sharddistributor/client/spectatorclient/interface_mock.go b/service/sharddistributor/client/spectatorclient/interface_mock.go index 5b1eaaa5500..0e68d476608 100644 --- a/service/sharddistributor/client/spectatorclient/interface_mock.go +++ b/service/sharddistributor/client/spectatorclient/interface_mock.go @@ -41,10 +41,10 @@ func (m *MockSpectator) EXPECT() *MockSpectatorMockRecorder { } // GetShardOwner mocks base method. -func (m *MockSpectator) GetShardOwner(ctx context.Context, shardKey string) (string, error) { +func (m *MockSpectator) GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetShardOwner", ctx, shardKey) - ret0, _ := ret[0].(string) + ret0, _ := ret[0].(*ShardOwner) ret1, _ := ret[1].(error) return ret0, ret1 }