Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions service/sharddistributor/client/spectatorclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,50 @@ import (

//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interface_mock.go . Spectator

type Spectators map[string]Spectator

func (s Spectators) ForNamespace(namespace string) (Spectator, error) {
spectator, ok := s[namespace]
if !ok {
return nil, fmt.Errorf("spectator not found for namespace %s", namespace)
}
return spectator, nil
}

func (s Spectators) Start(ctx context.Context) error {
for namespace, spectator := range s {
if err := spectator.Start(ctx); err != nil {
return fmt.Errorf("start spectator for namespace %s: %w", namespace, err)
}
}
return nil
}

func (s Spectators) Stop() {
for _, spectator := range s {
spectator.Stop()
}
}

func NewSpectators(params Params) (Spectators, error) {
spectators := make(Spectators)
for _, namespace := range params.Config.Namespaces {
spectator, err := NewSpectatorWithNamespace(params, namespace.Namespace)
if err != nil {
return nil, fmt.Errorf("create spectator for namespace %s: %w", namespace.Namespace, err)
}

spectators[namespace.Namespace] = spectator
}
return spectators, nil
}

type Spectator interface {
Start(ctx context.Context) error
Stop()

// GetShardOwner returns the owner of a shard. It first checks the local cache,
// and if not found, falls back to querying the shard distributor directly.
GetShardOwner(ctx context.Context, shardKey string) (string, error)
// GetShardOwner returns the owner of a shard
GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error)
}

type Params struct {
Expand Down Expand Up @@ -109,21 +146,9 @@ func createShardDistributorClient(yarpcClient sharddistributorv1.ShardDistributo
// Module creates a spectator module using auto-selection (single namespace only)
func Module() fx.Option {
return fx.Module("shard-distributor-spectator-client",
fx.Provide(NewSpectator),
fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) {
lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop))
}),
)
}

// ModuleWithNamespace creates a spectator module for a specific namespace
func ModuleWithNamespace(namespace string) fx.Option {
return fx.Module(fmt.Sprintf("shard-distributor-spectator-client-%s", namespace),
fx.Provide(func(params Params) (Spectator, error) {
return NewSpectatorWithNamespace(params, namespace)
}),
fx.Invoke(func(spectator Spectator, lc fx.Lifecycle) {
lc.Append(fx.StartStopHook(spectator.Start, spectator.Stop))
fx.Provide(NewSpectators),
fx.Invoke(func(spectators Spectators, lc fx.Lifecycle) {
lc.Append(fx.StartStopHook(spectators.Start, spectators.Stop))
}),
)
}
14 changes: 9 additions & 5 deletions service/sharddistributor/client/spectatorclient/clientimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func (s *spectatorImpl) watchLoop() {

// Server shutdown or network issue - recreate stream (load balancer will route to new server)
s.logger.Info("Stream ended, reconnecting", tag.ShardNamespace(s.namespace))
s.timeSource.Sleep(backoff.JitDuration(streamRetryInterval, streamRetryJitterCoeff))
}
}

Expand Down Expand Up @@ -163,10 +164,10 @@ func (s *spectatorImpl) handleResponse(response *types.WatchNamespaceStateRespon
tag.Counter(len(response.Executors)))
}

// GetShardOwner returns the executor ID for a given shard.
// GetShardOwner returns the full owner information including metadata for a given shard.
// It first waits for the initial state to be received, then checks the cache.
// If not found in cache, it falls back to querying the shard distributor directly.
func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (string, error) {
func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (*ShardOwner, error) {
// Wait for first state to be received to avoid flooding shard distributor on startup
s.firstStateWG.Wait()

Expand All @@ -176,7 +177,7 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str
s.stateMu.RUnlock()

if owner != nil {
return owner.ExecutorID, nil
return owner, nil
}

// Cache miss - fall back to RPC call
Expand All @@ -189,8 +190,11 @@ func (s *spectatorImpl) GetShardOwner(ctx context.Context, shardKey string) (str
ShardKey: shardKey,
})
if err != nil {
return "", fmt.Errorf("get shard owner from shard distributor: %w", err)
return nil, fmt.Errorf("get shard owner from shard distributor: %w", err)
}

return response.Owner, nil
return &ShardOwner{
ExecutorID: response.Owner,
Metadata: response.Metadata,
}, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ func TestWatchLoopBasicFlow(t *testing.T) {
Executors: []*types.ExecutorShardAssignment{
{
ExecutorID: "executor-1",
Metadata: map[string]string{
"grpc_address": "127.0.0.1:7953",
},
AssignedShards: []*types.Shard{
{ShardKey: "shard-1"},
{ShardKey: "shard-2"},
Expand Down Expand Up @@ -72,11 +75,12 @@ func TestWatchLoopBasicFlow(t *testing.T) {
// Query shard owner
owner, err := spectator.GetShardOwner(context.Background(), "shard-1")
assert.NoError(t, err)
assert.Equal(t, "executor-1", owner)
assert.Equal(t, "executor-1", owner.ExecutorID)
assert.Equal(t, "127.0.0.1:7953", owner.Metadata["grpc_address"])

owner, err = spectator.GetShardOwner(context.Background(), "shard-2")
assert.NoError(t, err)
assert.Equal(t, "executor-1", owner)
assert.Equal(t, "executor-1", owner.ExecutorID)
}

func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
Expand All @@ -103,7 +107,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
// First Recv returns state
mockStream.EXPECT().Recv().Return(&types.WatchNamespaceStateResponse{
Executors: []*types.ExecutorShardAssignment{
{ExecutorID: "executor-1", AssignedShards: []*types.Shard{{ShardKey: "shard-1"}}},
{
ExecutorID: "executor-1",
Metadata: map[string]string{
"grpc_address": "127.0.0.1:7953",
},
AssignedShards: []*types.Shard{{ShardKey: "shard-1"}},
},
},
}, nil)

Expand All @@ -122,7 +132,12 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
Namespace: "test-ns",
ShardKey: "unknown-shard",
}).
Return(&types.GetShardOwnerResponse{Owner: "executor-2"}, nil)
Return(&types.GetShardOwnerResponse{
Owner: "executor-2",
Metadata: map[string]string{
"grpc_address": "127.0.0.1:7954",
},
}, nil)

spectator.Start(context.Background())
defer spectator.Stop()
Expand All @@ -132,12 +147,13 @@ func TestGetShardOwner_CacheMiss_FallbackToRPC(t *testing.T) {
// Cache hit
owner, err := spectator.GetShardOwner(context.Background(), "shard-1")
assert.NoError(t, err)
assert.Equal(t, "executor-1", owner)
assert.Equal(t, "executor-1", owner.ExecutorID)

// Cache miss - should trigger RPC
owner, err = spectator.GetShardOwner(context.Background(), "unknown-shard")
assert.NoError(t, err)
assert.Equal(t, "executor-2", owner)
assert.Equal(t, "executor-2", owner.ExecutorID)
assert.Equal(t, "127.0.0.1:7954", owner.Metadata["grpc_address"])
}

func TestStreamReconnection(t *testing.T) {
Expand Down Expand Up @@ -188,7 +204,9 @@ func TestStreamReconnection(t *testing.T) {
spectator.Start(context.Background())
defer spectator.Stop()

// Advance time for retry
// Wait for the goroutine to be blocked in Sleep, then advance time
mockTimeSource.BlockUntil(1) // Wait for 1 goroutine to be blocked in Sleep
mockTimeSource.Advance(2 * time.Second)

spectator.firstStateWG.Wait()
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading