diff --git a/service/sharddistributor/client/spectatorclient/client.go b/service/sharddistributor/client/spectatorclient/client.go
index c4b19eb094a..8f6e2f1e8b3 100644
--- a/service/sharddistributor/client/spectatorclient/client.go
+++ b/service/sharddistributor/client/spectatorclient/client.go
@@ -20,13 +20,52 @@ import (
 
 //go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interface_mock.go . Spectator
 
+type Spectators struct {
+	spectators map[string]Spectator
+}
+
+func (s *Spectators) ForNamespace(namespace string) (Spectator, error) {
+	spectator, ok := s.spectators[namespace]
+	if !ok {
+		return nil, fmt.Errorf("spectator not found for namespace %s", namespace)
+	}
+	return spectator, nil
+}
+
+func (s *Spectators) Start(ctx context.Context) error {
+	for namespace, spectator := range s.spectators {
+		if err := spectator.Start(ctx); err != nil {
+			return fmt.Errorf("start spectator for namespace %s: %w", namespace, err)
+		}
+	}
+	return nil
+}
+
+func (s *Spectators) Stop() {
+	for _, spectator := range s.spectators {
+		spectator.Stop()
+	}
+}
+
+func NewSpectators(params Params) (*Spectators, error) {
+	spectators := make(map[string]Spectator)
+	for _, namespace := range params.Config.Namespaces {
+		spectator, err := NewSpectatorWithNamespace(params, namespace.Namespace)
+		if err != nil {
+			return nil, fmt.Errorf("create spectator for namespace %s: %w", namespace.Namespace, err)
+		}
+
+		spectators[namespace.Namespace] = spectator
+	}
+	return &Spectators{spectators: spectators}, nil
+}
+
 type Spectator interface {
 	Start(ctx context.Context) error
 	Stop()
 
-	// GetShardOwner returns the owner of a shard. It first checks the local cache,
-	// and if not found, falls back to querying the shard distributor directly.
-	GetShardOwner(ctx context.Context, shardKey string) (string, error)
+	// GetShardOwner returns the owner of a shard
+	GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error)
 }
 
 type Params struct {
@@ -109,21 +148,9 @@ func createShardDistributorClient(yarpcClient sharddistributorv1.ShardDistributo
 // Module creates a spectator module using auto-selection (single namespace only)
 func Module() fx.Option {
 	return fx.Module("shard-distributor-spectator-client",
-		fx.Provide(NewSpectator),
-		fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) {
-			lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop))
-		}),
-	)
-}
-
-// ModuleWithNamespace creates a spectator module for a specific namespace
-func ModuleWithNamespace(namespace string) fx.Option {
-	return fx.Module(fmt.Sprintf("shard-distributor-spectator-client-%s", namespace),
-		fx.Provide(func(params Params) (Spectator, error) {
-			return NewSpectatorWithNamespace(params, namespace)
-		}),
-		fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) {
-			lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop))
+		fx.Provide(NewSpectators),
+		fx.Invoke(func(spectators *Spectators, lc fx.Lifecycle) {
+			lc.Append(fx.StartStopHook(spectators.Start, spectators.Stop))
 		}),
 	)
 }
diff --git a/service/sharddistributor/client/spectatorclient/clientimpl.go b/service/sharddistributor/client/spectatorclient/clientimpl.go
index 88934cdadf6..15c6d32d8b5 100644
--- a/service/sharddistributor/client/spectatorclient/clientimpl.go
+++ b/service/sharddistributor/client/spectatorclient/clientimpl.go
@@ -103,6 +103,7 @@ func (s *spectatorImpl) watchLoop() {
 			// Server shutdown or network issue - recreate stream (load balancer will route to new server)
 			s.logger.Info("Stream ended, reconnecting",
 				tag.ShardNamespace(s.namespace))
+			s.timeSource.Sleep(backoff.JitDuration(streamRetryInterval, streamRetryJitterCoeff))
 		}
 	}
 
@@ -163,10 +164,10 @@ func (s *spectatorImpl) handleResponse(response *types.WatchNamespaceStateRespon
 		tag.Counter(len(response.Executors)))
 }
 
-// GetShardOwner returns the executor ID for a given shard.
+// GetShardOwner returns the full owner information including metadata for a given shard.
 // It first waits for the initial state to be received, then checks the cache.
 // If not found in cache, it falls back to querying the shard distributor directly.
-func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (string, error) {
+func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) {
 	// Wait for first state to be received to avoid flooding shard distributor on startup
 	s.firstStateWG.Wait()
 
@@ -176,7 +177,7 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str
 	s.stateMu.RUnlock()
 
 	if owner != nil {
-		return owner.ExecutorID, nil
+		return owner, nil
 	}
 
 	// Cache miss - fall back to RPC call
@@ -189,8 +190,11 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str
 		ShardKey:  shardKey,
 	})
 	if err != nil {
-		return "", fmt.Errorf("get shard owner from shard distributor: %w", err)
+		return nil, fmt.Errorf("get shard owner from shard distributor: %w", err)
 	}
 
-	return response.Owner, nil
+	return &ShardOwner{
+		ExecutorID: response.Owner,
+		Metadata:   response.Metadata,
+	}, nil
 }
diff --git a/service/sharddistributor/client/spectatorclient/clientimpl_test.go b/service/sharddistributor/client/spectatorclient/clientimpl_test.go
index a0c60aa2706..637883b1647 100644
--- a/service/sharddistributor/client/spectatorclient/clientimpl_test.go
+++ b/service/sharddistributor/client/spectatorclient/clientimpl_test.go
@@ -44,6 +44,9 @@ func TestWatchLoopBasicFlow(t *testing.T) {
 		Executors: []*types.ExecutorShardAssignment{
 			{
 				ExecutorID: "executor-1",
+				Metadata: map[string]string{
+					"grpc_address": "127.0.0.1:7953",
+				},
 				AssignedShards: []*types.Shard{
 					{ShardKey: "shard-1"},
 					{ShardKey: "shard-2"},
@@ -72,11 +75,12 @@ func TestWatchLoopBasicFlow(t *testing.T) {
 	// Query shard owner
 	owner, err := spectator.GetShardOwner(context.Background(), "shard-1")
 	assert.NoError(t, err)
-	assert.Equal(t, "executor-1", owner)
+	assert.Equal(t, "executor-1", owner.ExecutorID)
+	assert.Equal(t, "127.0.0.1:7953", owner.Metadata["grpc_address"])
 
 	owner, err = spectator.GetShardOwner(context.Background(), "shard-2")
 	assert.NoError(t, err)
-	assert.Equal(t, "executor-1", owner)
+	assert.Equal(t, "executor-1", owner.ExecutorID)
 }
 
 func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
@@ -103,7 +107,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
 	// First Recv returns state
 	mockStream.EXPECT().Recv().Return(&types.WatchNamespaceStateResponse{
 		Executors: []*types.ExecutorShardAssignment{
-			{ExecutorID: "executor-1", AssignedShards: []*types.Shard{{ShardKey: "shard-1"}}},
+			{
+				ExecutorID: "executor-1",
+				Metadata: map[string]string{
+					"grpc_address": "127.0.0.1:7953",
+				},
+				AssignedShards: []*types.Shard{{ShardKey: "shard-1"}},
+			},
 		},
 	}, nil)
 
@@ -122,7 +132,12 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
 			Namespace: "test-ns",
 			ShardKey:  "unknown-shard",
 		}).
-		Return(&types.GetShardOwnerResponse{Owner: "executor-2"}, nil)
+		Return(&types.GetShardOwnerResponse{
+			Owner: "executor-2",
+			Metadata: map[string]string{
+				"grpc_address": "127.0.0.1:7954",
+			},
+		}, nil)
 
 	spectator.Start(context.Background())
 	defer spectator.Stop()
@@ -132,12 +147,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
 	// Cache hit
 	owner, err := spectator.GetShardOwner(context.Background(), "shard-1")
 	assert.NoError(t, err)
-	assert.Equal(t, "executor-1", owner)
+	assert.Equal(t, "executor-1", owner.ExecutorID)
 
 	// Cache miss - should trigger RPC
 	owner, err = spectator.GetShardOwner(context.Background(), "unknown-shard")
 	assert.NoError(t, err)
-	assert.Equal(t, "executor-2", owner)
+	assert.Equal(t, "executor-2", owner.ExecutorID)
+	assert.Equal(t, "127.0.0.1:7954", owner.Metadata["grpc_address"])
 }
 
 func TestStreamReconnection(t *testing.T) {
@@ -188,7 +204,9 @@ func TestStreamReconnection(t *testing.T) {
 	spectator.Start(context.Background())
 	defer spectator.Stop()
 
-	// Advance time for retry
+	// Wait for the goroutine to be blocked in Sleep, then advance time
+	mockTimeSource.BlockUntil(1) // Wait for 1 goroutine to be blocked in Sleep
 	mockTimeSource.Advance(2 * time.Second)
+	spectator.firstStateWG.Wait()
 }
diff --git a/service/sharddistributor/client/spectatorclient/interface_mock.go b/service/sharddistributor/client/spectatorclient/interface_mock.go
index 5b1eaaa5500..0e68d476608 100644
--- a/service/sharddistributor/client/spectatorclient/interface_mock.go
+++ b/service/sharddistributor/client/spectatorclient/interface_mock.go
@@ -41,10 +41,10 @@ func (m *MockSpectator) EXPECT() *MockSpectatorMockRecorder {
 }
 
 // GetShardOwner mocks base method.
-func (m *MockSpectator) GetShardOwner(ctx context.Context, shardKey string) (string, error) {
+func (m *MockSpectator) GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "GetShardOwner", ctx, shardKey)
-	ret0, _ := ret[0].(string)
+	ret0, _ := ret[0].(*ShardOwner)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
diff --git a/service/sharddistributor/client/spectatorclient/peer_chooser.go b/service/sharddistributor/client/spectatorclient/peer_chooser.go
new file mode 100644
index 00000000000..4deb8f60d46
--- /dev/null
+++ b/service/sharddistributor/client/spectatorclient/peer_chooser.go
@@ -0,0 +1,176 @@
+package spectatorclient
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"go.uber.org/fx"
+	"go.uber.org/yarpc/api/peer"
+	"go.uber.org/yarpc/api/transport"
+	"go.uber.org/yarpc/peer/hostport"
+	"go.uber.org/yarpc/yarpcerrors"
+
+	"github.com/uber/cadence/common/log"
+	"github.com/uber/cadence/common/log/tag"
+)
+
+const (
+	NamespaceHeader        = "x-shard-distributor-namespace"
+	grpcAddressMetadataKey = "grpc_address"
+)
+
+// SpectatorPeerChooserInterface extends peer.Chooser with SetSpectators method
+type SpectatorPeerChooserInterface interface {
+	peer.Chooser
+	SetSpectators(spectators *Spectators)
+}
+
+// SpectatorPeerChooser is a peer.Chooser that uses the Spectator to route requests
+// to the correct executor based on shard ownership.
+// This is the shard distributor equivalent of Cadence's RingpopPeerChooser.
+//
+// Flow:
+// 1. Client calls RPC with yarpc.WithShardKey("shard-key")
+// 2. Choose() is called with req.ShardKey = "shard-key"
+// 3. Query Spectator for shard owner
+// 4. Extract grpc_address from owner metadata
+// 5. Create/reuse peer for that address
+// 6. Return peer to YARPC for connection
+type SpectatorPeerChooser struct {
+	spectators *Spectators
+	transport  peer.Transport
+	logger     log.Logger
+	namespace  string
+
+	mu    sync.RWMutex
+	peers map[string]peer.Peer // grpc_address -> peer
+}
+
+type SpectatorPeerChooserParams struct {
+	fx.In
+	Transport peer.Transport
+	Logger    log.Logger
+}
+
+// NewSpectatorPeerChooser creates a new peer chooser that routes based on shard distributor ownership
+func NewSpectatorPeerChooser(
+	params SpectatorPeerChooserParams,
+) SpectatorPeerChooserInterface {
+	return &SpectatorPeerChooser{
+		transport: params.Transport,
+		logger:    params.Logger,
+		peers:     make(map[string]peer.Peer),
+	}
+}
+
+// Start satisfies the peer.Chooser interface
+func (c *SpectatorPeerChooser) Start() error {
+	c.logger.Info("Starting shard distributor peer chooser", tag.ShardNamespace(c.namespace))
+	return nil
+}
+
+// Stop satisfies the peer.Chooser interface
+func (c *SpectatorPeerChooser) Stop() error {
+	c.logger.Info("Stopping shard distributor peer chooser", tag.ShardNamespace(c.namespace))
+
+	// Release all peers
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for addr, p := range c.peers {
+		if err := c.transport.ReleasePeer(p, &noOpSubscriber{}); err != nil {
+			c.logger.Error("Failed to release peer", tag.Error(err), tag.Address(addr))
+		}
+	}
+	c.peers = make(map[string]peer.Peer)
+
+	return nil
+}
+
+// IsRunning satisfies the peer.Chooser interface
+func (c *SpectatorPeerChooser) IsRunning() bool {
+	return true
+}
+
+// Choose returns a peer for the given shard key by:
+// 0. Looking up the spectator for the namespace using the x-shard-distributor-namespace header
+// 1. Looking up the shard owner via the Spectator
+// 2. Extracting the grpc_address from the owner's metadata
+// 3. Creating/reusing a peer for that address
+//
+// The ShardKey in the request is the actual shard key (e.g., workflow ID, shard ID),
+// NOT the ip:port address. This is the key distinction from directPeerChooser.
+func (c *SpectatorPeerChooser) Choose(ctx context.Context, req *transport.Request) (peer.Peer, func(error), error) {
+	if req.ShardKey == "" {
+		return nil, nil, yarpcerrors.InvalidArgumentErrorf("chooser requires ShardKey to be non-empty")
+	}
+
+	// Get the spectator for the namespace
+	namespace, ok := req.Headers.Get(NamespaceHeader)
+	if !ok || namespace == "" {
+		return nil, nil, yarpcerrors.InvalidArgumentErrorf("chooser requires x-shard-distributor-namespace header to be non-empty")
+	}
+
+	spectator, err := c.spectators.ForNamespace(namespace)
+	if err != nil {
+		return nil, nil, yarpcerrors.InvalidArgumentErrorf("failed to get spectator for namespace %s: %w", namespace, err)
+	}
+
+	// Query spectator for shard owner
+	owner, err := spectator.GetShardOwner(ctx, req.ShardKey)
+	if err != nil {
+		return nil, nil, yarpcerrors.UnavailableErrorf("failed to get shard owner for key %s: %v", req.ShardKey, err)
+	}
+
+	// Extract GRPC address from owner metadata
+	grpcAddress, ok := owner.Metadata[grpcAddressMetadataKey]
+	if !ok || grpcAddress == "" {
+		return nil, nil, yarpcerrors.InternalErrorf("no grpc_address in metadata for executor %s owning shard %s", owner.ExecutorID, req.ShardKey)
+	}
+
+	// Check if we already have a peer for this address
+	c.mu.RLock()
+	p, ok := c.peers[grpcAddress]
+	if ok {
+		c.mu.RUnlock()
+		return p, func(error) {}, nil
+	}
+	c.mu.RUnlock()
+
+	// Create new peer for this address
+	p, err = c.addPeer(grpcAddress)
+	if err != nil {
+		return nil, nil, yarpcerrors.InternalErrorf("failed to add peer for address %s: %v", grpcAddress, err)
+	}
+
+	return p, func(error) {}, nil
+}
+
+func (c *SpectatorPeerChooser) SetSpectators(spectators *Spectators) {
+	c.spectators = spectators
+}
+
+func (c *SpectatorPeerChooser) addPeer(grpcAddress string) (peer.Peer, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Check again in case another goroutine added it
+	if p, ok := c.peers[grpcAddress]; ok {
+		return p, nil
+	}
+
+	p, err := c.transport.RetainPeer(hostport.Identify(grpcAddress), &noOpSubscriber{})
+	if err != nil {
+		return nil, fmt.Errorf("retain peer failed: %w", err)
+	}
+
+	c.peers[grpcAddress] = p
+	c.logger.Info("Added peer to shard distributor peer chooser", tag.Address(grpcAddress))
+	return p, nil
+}
+
+// noOpSubscriber is a no-op implementation of peer.Subscriber
+type noOpSubscriber struct{}
+
+func (*noOpSubscriber) NotifyStatusChanged(peer.Identifier) {}
diff --git a/service/sharddistributor/client/spectatorclient/peer_chooser_test.go b/service/sharddistributor/client/spectatorclient/peer_chooser_test.go
new file mode 100644
index 00000000000..569c354d01f
--- /dev/null
+++ b/service/sharddistributor/client/spectatorclient/peer_chooser_test.go
@@ -0,0 +1,189 @@
+package spectatorclient
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+	"go.uber.org/yarpc/api/peer"
+	"go.uber.org/yarpc/api/transport"
+	"go.uber.org/yarpc/transport/grpc"
+
+	"github.com/uber/cadence/common/log/testlogger"
+)
+
+func TestSpectatorPeerChooser_Choose_MissingShardKey(t *testing.T) {
+	chooser := &SpectatorPeerChooser{
+		logger: testlogger.New(t),
+		peers:  make(map[string]peer.Peer),
+	}
+
+	req := &transport.Request{
+		ShardKey: "",
+		Headers:  transport.NewHeaders(),
+	}
+
+	p, onFinish, err := chooser.Choose(context.Background(), req)
+
+	assert.Error(t, err)
+	assert.Nil(t, p)
+	assert.Nil(t, onFinish)
+	assert.Contains(t, err.Error(), "ShardKey")
+}
+
+func TestSpectatorPeerChooser_Choose_MissingNamespaceHeader(t *testing.T) {
+	chooser := &SpectatorPeerChooser{
+		logger: testlogger.New(t),
+		peers:  make(map[string]peer.Peer),
+	}
+
+	req := &transport.Request{
+		ShardKey: "shard-1",
+		Headers:  transport.NewHeaders(),
+	}
+
+	p, onFinish, err := chooser.Choose(context.Background(), req)
+
+	assert.Error(t, err)
+	assert.Nil(t, p)
+	assert.Nil(t, onFinish)
+	assert.Contains(t, err.Error(), "x-shard-distributor-namespace")
+}
+
+func TestSpectatorPeerChooser_Choose_SpectatorNotFound(t *testing.T) {
+	chooser := &SpectatorPeerChooser{
+		logger:     testlogger.New(t),
+		peers:      make(map[string]peer.Peer),
+		spectators: &Spectators{spectators: make(map[string]Spectator)},
+	}
+
+	req := &transport.Request{
+		ShardKey: "shard-1",
+		Headers:  transport.NewHeaders().With(NamespaceHeader, "unknown-namespace"),
+	}
+
+	p, onFinish, err := chooser.Choose(context.Background(), req)
+
+	assert.Error(t, err)
+	assert.Nil(t, p)
+	assert.Nil(t, onFinish)
+	assert.Contains(t, err.Error(), "spectator not found")
+}
+
+func TestSpectatorPeerChooser_StartStop(t *testing.T) {
+	chooser := &SpectatorPeerChooser{
+		logger: testlogger.New(t),
+		peers:  make(map[string]peer.Peer),
+	}
+
+	err := chooser.Start()
+	require.NoError(t, err)
+
+	assert.True(t, chooser.IsRunning())
+
+	err = chooser.Stop()
+	assert.NoError(t, err)
+}
+
+func TestSpectatorPeerChooser_SetSpectators(t *testing.T) {
+	chooser := &SpectatorPeerChooser{
+		logger: testlogger.New(t),
+	}
+
+	spectators := &Spectators{spectators: make(map[string]Spectator)}
+	chooser.SetSpectators(spectators)
+
+	assert.Equal(t, spectators, chooser.spectators)
+}
+
+func TestSpectatorPeerChooser_Choose_Success(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockSpectator := NewMockSpectator(ctrl)
+	peerTransport := grpc.NewTransport()
+
+	chooser := &SpectatorPeerChooser{
+		transport: peerTransport,
+		logger:    testlogger.New(t),
+		peers:     make(map[string]peer.Peer),
+		spectators: &Spectators{
+			spectators: map[string]Spectator{
+				"test-namespace": mockSpectator,
+			},
+		},
+	}
+
+	ctx := context.Background()
+	req := &transport.Request{
+		ShardKey: "shard-1",
+		Headers:  transport.NewHeaders().With(NamespaceHeader, "test-namespace"),
+	}
+
+	// Mock spectator to return shard owner with grpc_address
+	mockSpectator.EXPECT().
+		GetShardOwner(ctx, "shard-1").
+		Return(&ShardOwner{
+			ExecutorID: "executor-1",
+			Metadata: map[string]string{
+				grpcAddressMetadataKey: "127.0.0.1:7953",
+			},
+		}, nil)
+
+	// Execute
+	p, onFinish, err := chooser.Choose(ctx, req)
+
+	// Assert
+	assert.NoError(t, err)
+	assert.NotNil(t, p)
+	assert.NotNil(t, onFinish)
+	assert.Equal(t, "127.0.0.1:7953", p.Identifier())
+	assert.Len(t, chooser.peers, 1)
+}
+
+func TestSpectatorPeerChooser_Choose_ReusesPeer(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockSpectator := NewMockSpectator(ctrl)
+	peerTransport := grpc.NewTransport()
+
+	chooser := &SpectatorPeerChooser{
+		transport: peerTransport,
+		logger:    testlogger.New(t),
+		peers:     make(map[string]peer.Peer),
+		spectators: &Spectators{
+			spectators: map[string]Spectator{
+				"test-namespace": mockSpectator,
+			},
+		},
+	}
+
+	req := &transport.Request{
+		ShardKey: "shard-1",
+		Headers:  transport.NewHeaders().With(NamespaceHeader, "test-namespace"),
+	}
+
+	// First call creates the peer
+	mockSpectator.EXPECT().
+		GetShardOwner(gomock.Any(), "shard-1").
+		Return(&ShardOwner{
+			ExecutorID: "executor-1",
+			Metadata: map[string]string{
+				grpcAddressMetadataKey: "127.0.0.1:7953",
+			},
+		}, nil).Times(2)
+
+	firstPeer, _, err := chooser.Choose(context.Background(), req)
+	require.NoError(t, err)
+
+	// Second call should reuse the same peer
+	secondPeer, _, err := chooser.Choose(context.Background(), req)
+
+	// Assert - should reuse existing peer
+	assert.NoError(t, err)
+	assert.Equal(t, firstPeer, secondPeer)
+	assert.Len(t, chooser.peers, 1)
+}
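
Usage sketch (not part of the diff): the snippet below shows how a caller might combine the pieces introduced above, resolving an owner directly through Spectators and building the per-call options that let SpectatorPeerChooser route via the namespace header plus the yarpc shard key. The namespace "my-namespace", the shard key "shard-1", the function name, and the generated executor client mentioned in the trailing comment are placeholders; only ForNamespace, GetShardOwner, ShardOwner.Metadata, NamespaceHeader, and the "grpc_address" metadata key come from this change.

package example

import (
	"context"

	"go.uber.org/yarpc"

	"github.com/uber/cadence/service/sharddistributor/client/spectatorclient"
)

// routeThroughSpectator is illustrative only: it resolves a shard owner directly
// and builds the call options that make SpectatorPeerChooser pick the owner's peer.
func routeThroughSpectator(ctx context.Context, spectators *spectatorclient.Spectators) error {
	// Direct lookup: namespace -> spectator -> owner (ExecutorID + metadata).
	spectator, err := spectators.ForNamespace("my-namespace") // placeholder namespace
	if err != nil {
		return err
	}
	owner, err := spectator.GetShardOwner(ctx, "shard-1") // placeholder shard key
	if err != nil {
		return err
	}
	_ = owner.Metadata["grpc_address"] // the address Choose() dials for this shard

	// Chooser-based routing: the header selects the spectator, the shard key selects the owner.
	opts := []yarpc.CallOption{
		yarpc.WithHeader(spectatorclient.NamespaceHeader, "my-namespace"),
		yarpc.WithShardKey("shard-1"),
	}
	_ = opts // pass these to a generated executor client call, e.g. client.SomeRPC(ctx, req, opts...)
	return nil
}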