diff --git a/cmd/utils/bor_flags.go b/cmd/utils/bor_flags.go index f5f719f79f..3c87196dcf 100644 --- a/cmd/utils/bor_flags.go +++ b/cmd/utils/bor_flags.go @@ -16,10 +16,10 @@ var ( // Bor Specific flags // - // HeimdallURLFlag flag for heimdall url + // HeimdallURLFlag flag for heimdall url (comma-separated for failover) HeimdallURLFlag = &cli.StringFlag{ Name: "bor.heimdall", - Usage: "URL of Heimdall service", + Usage: "URL of Heimdall service (comma-separated for failover: \"url1,url2\")", Value: "http://localhost:1317", } @@ -36,17 +36,17 @@ var ( Usage: "Run without Heimdall service (for testing purpose)", } - // HeimdallgRPCAddressFlag flag for heimdall gRPC address + // HeimdallgRPCAddressFlag flag for heimdall gRPC address (comma-separated for failover) HeimdallgRPCAddressFlag = &cli.StringFlag{ Name: "bor.heimdallgRPC", - Usage: "Address of Heimdall gRPC service", + Usage: "Address of Heimdall gRPC service (comma-separated for failover: \"addr1,addr2\")", Value: "", } - // HeimdallWSAddressFlag flag for heimdall websocket subscription service + // HeimdallWSAddressFlag flag for heimdall websocket subscription service (comma-separated for failover) HeimdallWSAddressFlag = &cli.StringFlag{ Name: "bor.heimdallWS", - Usage: "Address of Heimdall WS Subscription service", + Usage: "Address of Heimdall WS Subscription service (comma-separated for failover: \"addr1,addr2\")", Value: "", } diff --git a/consensus/bor/heimdall/client.go b/consensus/bor/heimdall/client.go index a27aa4f6aa..d8a4878d83 100644 --- a/consensus/bor/heimdall/client.go +++ b/consensus/bor/heimdall/client.go @@ -39,6 +39,20 @@ var ( ErrServiceUnavailable = errors.New("service unavailable") ) +// HTTPStatusError is returned when Heimdall responds with a non-2xx, non-503 status code. +// It wraps ErrNotSuccessfulResponse for backwards-compatibility with errors.Is checks. +type HTTPStatusError struct { + StatusCode int +} + +func (e *HTTPStatusError) Error() string { + return fmt.Sprintf("%s: response code %d", ErrNotSuccessfulResponse.Error(), e.StatusCode) +} + +func (e *HTTPStatusError) Unwrap() error { + return ErrNotSuccessfulResponse +} + const ( heimdallAPIBodyLimit = 128 * 1024 * 1024 // 128 MB stateFetchLimit = 50 @@ -455,7 +469,7 @@ func internalFetch(ctx context.Context, client http.Client, u *url.URL) ([]byte, // check status code if res.StatusCode != 200 && res.StatusCode != 204 { - return nil, fmt.Errorf("%w: response code %d", ErrNotSuccessfulResponse, res.StatusCode) + return nil, &HTTPStatusError{StatusCode: res.StatusCode} } // unmarshall data from buffer diff --git a/consensus/bor/heimdall/failover_client.go b/consensus/bor/heimdall/failover_client.go new file mode 100644 index 0000000000..9b20269ff2 --- /dev/null +++ b/consensus/bor/heimdall/failover_client.go @@ -0,0 +1,312 @@ +package heimdall + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/0xPolygon/heimdall-v2/x/bor/types" + ctypes "github.com/cometbft/cometbft/rpc/core/types" + + "github.com/ethereum/go-ethereum/consensus/bor/clerk" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall/checkpoint" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall/milestone" + "github.com/ethereum/go-ethereum/log" +) + +const ( + defaultAttemptTimeout = 30 * time.Second + defaultProbeTimeout = 5 * time.Second + defaultHealthCheckInterval = 10 * time.Second + defaultConsecutiveThreshold = 3 + defaultPromotionCooldown = 60 * time.Second +) + +// Endpoint matches bor.IHeimdallClient. It is exported so that external +// packages can build []Endpoint slices for NewMultiHeimdallClient without +// running into Go's covariant-slice restriction. +type Endpoint interface { + StateSyncEvents(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) + GetSpan(ctx context.Context, spanID uint64) (*types.Span, error) + GetLatestSpan(ctx context.Context) (*types.Span, error) + FetchCheckpoint(ctx context.Context, number int64) (*checkpoint.Checkpoint, error) + FetchCheckpointCount(ctx context.Context) (int64, error) + FetchMilestone(ctx context.Context) (*milestone.Milestone, error) + FetchMilestoneCount(ctx context.Context) (int64, error) + FetchStatus(ctx context.Context) (*ctypes.SyncInfo, error) + Close() +} + +// MultiHeimdallClient wraps N heimdall clients (primary at index 0, failovers +// at 1..N-1) and transparently cascades through them when the active client is +// unreachable. A background health registry continuously probes ALL endpoints, +// requires consecutive successes + cooldown before promotion, and gives cascade +// full visibility into endpoint health. +type MultiHeimdallClient struct { + clients []Endpoint + registry *HealthRegistry + attemptTimeout time.Duration + probeTimeout time.Duration + probeCtx context.Context // cancelled on Close to abort in-flight probes + probeCancel context.CancelFunc +} + +func NewMultiHeimdallClient(clients ...Endpoint) (*MultiHeimdallClient, error) { + if len(clients) == 0 { + return nil, fmt.Errorf("NewMultiHeimdallClient requires at least one client") + } + + probeCtx, probeCancel := context.WithCancel(context.Background()) + + f := &MultiHeimdallClient{ + clients: clients, + attemptTimeout: defaultAttemptTimeout, + probeTimeout: defaultProbeTimeout, + probeCtx: probeCtx, + probeCancel: probeCancel, + } + + f.registry = NewHealthRegistry( + len(clients), + f.probeEndpoint, + nil, // HTTP client doesn't need onSwitch callback + RegistryMetrics{ + ProbeAttempts: failoverProbeAttempts, + ProbeSuccesses: failoverProbeSuccesses, + ProactiveSwitches: failoverProactiveSwitches, + ActiveGauge: failoverActiveGauge, + HealthyEndpoints: failoverHealthyEndpoints, + }, + ) + + return f, nil +} + +// probeEndpoint probes a single endpoint via FetchStatus. +func (f *MultiHeimdallClient) probeEndpoint(i int) error { + ctx, cancel := context.WithTimeout(f.probeCtx, f.probeTimeout) + defer cancel() + + _, err := f.clients[i].FetchStatus(ctx) + + return err +} + +// ensureHealthRegistry lazily starts the health registry goroutine on the first +// API call. This allows tests to configure fields (thresholds, intervals) after +// construction but before the goroutine reads them. +func (f *MultiHeimdallClient) ensureHealthRegistry() { + if len(f.clients) > 1 { + f.registry.Start() + } +} + +func (f *MultiHeimdallClient) StateSyncEvents(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) ([]*clerk.EventRecordWithTime, error) { + return c.StateSyncEvents(ctx, fromID, to) + }) +} + +func (f *MultiHeimdallClient) GetSpan(ctx context.Context, spanID uint64) (*types.Span, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (*types.Span, error) { + return c.GetSpan(ctx, spanID) + }) +} + +func (f *MultiHeimdallClient) GetLatestSpan(ctx context.Context) (*types.Span, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (*types.Span, error) { + return c.GetLatestSpan(ctx) + }) +} + +func (f *MultiHeimdallClient) FetchCheckpoint(ctx context.Context, number int64) (*checkpoint.Checkpoint, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (*checkpoint.Checkpoint, error) { + return c.FetchCheckpoint(ctx, number) + }) +} + +func (f *MultiHeimdallClient) FetchCheckpointCount(ctx context.Context) (int64, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (int64, error) { + return c.FetchCheckpointCount(ctx) + }) +} + +func (f *MultiHeimdallClient) FetchMilestone(ctx context.Context) (*milestone.Milestone, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (*milestone.Milestone, error) { + return c.FetchMilestone(ctx) + }) +} + +func (f *MultiHeimdallClient) FetchMilestoneCount(ctx context.Context) (int64, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (int64, error) { + return c.FetchMilestoneCount(ctx) + }) +} + +func (f *MultiHeimdallClient) FetchStatus(ctx context.Context) (*ctypes.SyncInfo, error) { + return callWithFailover(f, ctx, func(ctx context.Context, c Endpoint) (*ctypes.SyncInfo, error) { + return c.FetchStatus(ctx) + }) +} + +func (f *MultiHeimdallClient) Close() { + f.probeCancel() // cancel in-flight probes first + f.registry.Stop() + + for _, c := range f.clients { + c.Close() + } +} + +// callWithFailover executes fn against the active client. If the active client +// fails with a failover-eligible error, it marks it unhealthy and cascades +// through remaining clients using health registry information. +func callWithFailover[T any](f *MultiHeimdallClient, ctx context.Context, fn func(context.Context, Endpoint) (T, error)) (T, error) { + f.ensureHealthRegistry() + + active := f.registry.Active() + + subCtx, cancel := context.WithTimeout(ctx, f.attemptTimeout) + result, err := fn(subCtx, f.clients[active]) + cancel() + + if err == nil { + return result, nil + } + + if !isFailoverError(err, ctx) { + var zero T + return zero, err + } + + // Mark the active endpoint unhealthy in the registry. + f.registry.MarkUnhealthy(active, err) + + if active == 0 { + log.Warn("Heimdall failover: primary failed, cascading", "err", err) + } + + return cascadeClients(f, ctx, fn, active, err) +} + +// cascadeClients tries all endpoints in priority order using health registry +// information. It uses a three-pass approach: +// 1. Healthy + cooled endpoints in priority order (skipping failed active) +// 2. Healthy but NOT cooled endpoints in priority order +// 3. Unhealthy endpoints in priority order (last resort) +func cascadeClients[T any](f *MultiHeimdallClient, ctx context.Context, fn func(context.Context, Endpoint) (T, error), failed int, lastErr error) (T, error) { + n := len(f.clients) + + // Build candidate lists based on health state. + snap := f.registry.HealthSnapshot() + cooldown := f.registry.PromotionCooldown + + var cooled, uncooled, unhealthy []int + + for i := 0; i < n; i++ { + if i == failed { + continue + } + + if snap[i].Healthy { + if time.Since(snap[i].HealthySince) >= cooldown { + cooled = append(cooled, i) + } else { + uncooled = append(uncooled, i) + } + } else { + unhealthy = append(unhealthy, i) + } + } + + // Try each pass in order. + passes := [][]int{cooled, uncooled, unhealthy} + + for _, candidates := range passes { + for _, i := range candidates { + subCtx, cancel := context.WithTimeout(ctx, f.attemptTimeout) + result, err := fn(subCtx, f.clients[i]) + cancel() + + if err == nil { + f.registry.SetActive(i) + f.registry.MarkSuccess(i) + + failoverSwitchCounter.Inc(1) + + log.Warn("Heimdall failover: switched to client", "index", i) + + return result, nil + } + + lastErr = err + + if !isFailoverError(err, ctx) { + var zero T + return zero, err + } + + // Mark this endpoint unhealthy too. + f.registry.MarkUnhealthy(i, err) + } + } + + var zero T + return zero, lastErr +} + +// isFailoverError returns true if the error warrants trying the secondary. +// It distinguishes between sub-context timeouts (failover-eligible) and +// caller context cancellation (not eligible). +func isFailoverError(err error, callerCtx context.Context) bool { + if err == nil { + return false + } + + // If the caller's context is done, this is not a failover scenario + if callerCtx.Err() != nil { + return false + } + + // Shutdown detected - not a transport error + if errors.Is(err, ErrShutdownDetected) { + return false + } + + // 503 is a Heimdall feature-gate, not a transport issue + if errors.Is(err, ErrServiceUnavailable) { + return false + } + + // Transport errors + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + + // No response from Heimdall + if errors.Is(err, ErrNoResponse) { + return true + } + + // Server-side HTTP error (5xx, excluding 503 which is already handled above). + // Client errors (4xx) are logical errors; the secondary would return the same response. + var httpErr *HTTPStatusError + if errors.As(err, &httpErr) { + return httpErr.StatusCode >= 500 + } + + // Sub-context deadline exceeded (the caller's context is still alive at this point) + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + // Context canceled from sub-context (caller ctx is still alive) + if errors.Is(err, context.Canceled) { + return true + } + + return false +} diff --git a/consensus/bor/heimdall/failover_client_test.go b/consensus/bor/heimdall/failover_client_test.go new file mode 100644 index 0000000000..1ed5740ddd --- /dev/null +++ b/consensus/bor/heimdall/failover_client_test.go @@ -0,0 +1,1257 @@ +package heimdall + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/0xPolygon/heimdall-v2/x/bor/types" + ctypes "github.com/cometbft/cometbft/rpc/core/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/consensus/bor/clerk" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall/checkpoint" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall/milestone" +) + +// mockHeimdallClient is a configurable mock implementing the Endpoint interface. +type mockHeimdallClient struct { + getSpanFn func(ctx context.Context, spanID uint64) (*types.Span, error) + getLatestSpanFn func(ctx context.Context) (*types.Span, error) + stateSyncEventsFn func(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) + fetchCheckpointFn func(ctx context.Context, number int64) (*checkpoint.Checkpoint, error) + fetchCheckpointCntFn func(ctx context.Context) (int64, error) + fetchMilestoneFn func(ctx context.Context) (*milestone.Milestone, error) + fetchMilestoneCntFn func(ctx context.Context) (int64, error) + fetchStatusFn func(ctx context.Context) (*ctypes.SyncInfo, error) + closeFn func() + hits atomic.Int32 +} + +func (m *mockHeimdallClient) StateSyncEvents(ctx context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) { + m.hits.Add(1) + + if m.stateSyncEventsFn != nil { + return m.stateSyncEventsFn(ctx, fromID, to) + } + + return []*clerk.EventRecordWithTime{}, nil +} + +func (m *mockHeimdallClient) GetSpan(ctx context.Context, spanID uint64) (*types.Span, error) { + m.hits.Add(1) + + if m.getSpanFn != nil { + return m.getSpanFn(ctx, spanID) + } + + return &types.Span{Id: spanID}, nil +} + +func (m *mockHeimdallClient) GetLatestSpan(ctx context.Context) (*types.Span, error) { + m.hits.Add(1) + + if m.getLatestSpanFn != nil { + return m.getLatestSpanFn(ctx) + } + + return &types.Span{Id: 99}, nil +} + +func (m *mockHeimdallClient) FetchCheckpoint(ctx context.Context, number int64) (*checkpoint.Checkpoint, error) { + m.hits.Add(1) + + if m.fetchCheckpointFn != nil { + return m.fetchCheckpointFn(ctx, number) + } + + return &checkpoint.Checkpoint{}, nil +} + +func (m *mockHeimdallClient) FetchCheckpointCount(ctx context.Context) (int64, error) { + m.hits.Add(1) + + if m.fetchCheckpointCntFn != nil { + return m.fetchCheckpointCntFn(ctx) + } + + return 10, nil +} + +func (m *mockHeimdallClient) FetchMilestone(ctx context.Context) (*milestone.Milestone, error) { + m.hits.Add(1) + + if m.fetchMilestoneFn != nil { + return m.fetchMilestoneFn(ctx) + } + + return &milestone.Milestone{}, nil +} + +func (m *mockHeimdallClient) FetchMilestoneCount(ctx context.Context) (int64, error) { + m.hits.Add(1) + + if m.fetchMilestoneCntFn != nil { + return m.fetchMilestoneCntFn(ctx) + } + + return 5, nil +} + +func (m *mockHeimdallClient) FetchStatus(ctx context.Context) (*ctypes.SyncInfo, error) { + m.hits.Add(1) + + if m.fetchStatusFn != nil { + return m.fetchStatusFn(ctx) + } + + return &ctypes.SyncInfo{}, nil +} + +func (m *mockHeimdallClient) Close() { + if m.closeFn != nil { + m.closeFn() + } +} + +// testConnErr is a reusable connection-refused error for tests. +var testConnErr = &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + +// newConnRefusedMock creates a mock where both API calls and health probes always fail. +func newConnRefusedMock() *mockHeimdallClient { + return &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, testConnErr + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + return nil, testConnErr + }, + } +} + +// newToggleMock creates a mock whose API calls and health probes fail when down.Load() is true. +func newToggleMock(down *atomic.Bool) *mockHeimdallClient { + return &mockHeimdallClient{ + getSpanFn: func(_ context.Context, spanID uint64) (*types.Span, error) { + if down.Load() { + return nil, testConnErr + } + return &types.Span{Id: spanID}, nil + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + if down.Load() { + return nil, testConnErr + } + return &ctypes.SyncInfo{}, nil + }, + } +} + +// newProbeToggleMock creates a mock where API calls always fail but health probes +// succeed when down.Load() is false. +func newProbeToggleMock(down *atomic.Bool) *mockHeimdallClient { + return &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, testConnErr + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + if down.Load() { + return nil, testConnErr + } + return &ctypes.SyncInfo{}, nil + }, + } +} + +// newInstantMulti creates a MultiHeimdallClient with instant health registry +// behavior: consecutiveThreshold=1, promotionCooldown=0, fast health-check interval. +func newInstantMulti(clients ...Endpoint) *MultiHeimdallClient { + fc, err := NewMultiHeimdallClient(clients...) + if err != nil { + panic(err) + } + + fc.attemptTimeout = 100 * time.Millisecond + fc.probeTimeout = 100 * time.Millisecond + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + fc.registry.HealthCheckInterval = 50 * time.Millisecond + + return fc +} + +func TestNewMultiHeimdallClient_NoClients_ReturnsError(t *testing.T) { + _, err := NewMultiHeimdallClient() + require.Error(t, err) +} + +func TestFailover_SwitchOnPrimaryDown(t *testing.T) { + switchesBefore := failoverSwitchCounter.Snapshot().Count() + activeBefore := failoverActiveGauge.Snapshot().Value() + + primary := &mockHeimdallClient{ + getSpanFn: func(ctx context.Context, _ uint64) (*types.Span, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + + assert.GreaterOrEqual(t, primary.hits.Load(), int32(1), "primary should have been tried") + assert.GreaterOrEqual(t, secondary.hits.Load(), int32(1), "secondary should have been called") + + assert.Greater(t, failoverSwitchCounter.Snapshot().Count(), switchesBefore, "failover switch counter should increment") + _ = activeBefore // gauge is set, not incremented + assert.Equal(t, int64(1), failoverActiveGauge.Snapshot().Value(), "active gauge should reflect secondary index") +} + +func TestFailover_NoSwitchOnContextCanceled(t *testing.T) { + primary := &mockHeimdallClient{ + getSpanFn: func(ctx context.Context, _ uint64) (*types.Span, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 5 * time.Second // longer than caller's ctx + fc.probeTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Start registry and let the immediate probe cycle complete so its + // FetchStatus hits don't race with the assertion below. + fc.ensureHealthRegistry() + time.Sleep(50 * time.Millisecond) + + secondaryBefore := secondary.hits.Load() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = fc.GetSpan(ctx, 1) + require.Error(t, err) + assert.Equal(t, secondaryBefore, secondary.hits.Load(), "should not failover on caller context cancellation") +} + +func TestFailover_NoSwitchOnServiceUnavailable(t *testing.T) { + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, ErrServiceUnavailable + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + _, err = fc.GetSpan(context.Background(), 1) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrServiceUnavailable)) + assert.Equal(t, int32(0), secondary.hits.Load(), "should not failover on 503") +} + +func TestFailover_NoSwitchOnShutdownDetected(t *testing.T) { + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, ErrShutdownDetected + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + _, err = fc.GetSpan(context.Background(), 1) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShutdownDetected)) + assert.Equal(t, int32(0), secondary.hits.Load(), "should not failover on shutdown") +} + +func TestFailover_StickyBehavior(t *testing.T) { + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.probeTimeout = 100 * time.Millisecond + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + fc.registry.HealthCheckInterval = 1 * time.Hour // very long — no background promotion + defer fc.Close() + + // First call triggers failover + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Wait for the immediate probe cycle (launched by ensureHealthRegistry + // inside the first GetSpan call) to complete before snapshotting hits. + time.Sleep(50 * time.Millisecond) + + primaryBefore := primary.hits.Load() + secondaryBefore := secondary.hits.Load() + + // Subsequent calls should go directly to secondary without trying primary + for i := 0; i < 3; i++ { + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + } + + assert.Equal(t, primaryBefore, primary.hits.Load(), "primary should not be contacted while sticky") + assert.Equal(t, secondaryBefore+3, secondary.hits.Load(), "all calls should go to secondary") +} + +func TestFailover_ProbeBackToPrimary(t *testing.T) { + primaryDown := atomic.Bool{} + primaryDown.Store(true) + + primary := newToggleMock(&primaryDown) + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + // Trigger failover + _, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Bring primary back + primaryDown.Store(false) + + // Wait for background health registry to promote primary + require.Eventually(t, func() bool { + return fc.registry.Active() == 0 + }, 2*time.Second, 20*time.Millisecond, "health registry should promote back to primary") + + // Verify subsequent calls go to primary + secondaryBefore := secondary.hits.Load() + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + assert.Equal(t, secondaryBefore, secondary.hits.Load(), "should be back on primary now") +} + +func TestFailover_ProbeBackFails(t *testing.T) { + primary := newConnRefusedMock() + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + // Trigger failover + _, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Wait for a few health-check ticks + time.Sleep(200 * time.Millisecond) + + // Active should still be on secondary since primary FetchStatus fails + assert.Equal(t, 1, fc.registry.Active(), "should stay on secondary when primary still down") + + // Calls should still succeed via secondary + secondaryBefore := secondary.hits.Load() + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + assert.Greater(t, secondary.hits.Load(), secondaryBefore, "should still use secondary") +} + +func TestFailover_ClosesBothClients(t *testing.T) { + var primaryClosed, secondaryClosed atomic.Bool + + primary := &mockHeimdallClient{closeFn: func() { primaryClosed.Store(true) }} + secondary := &mockHeimdallClient{closeFn: func() { secondaryClosed.Store(true) }} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.Close() + + assert.True(t, primaryClosed.Load(), "primary should be closed") + assert.True(t, secondaryClosed.Load(), "secondary should be closed") +} + +func TestFailover_PassthroughWhenPrimaryHealthy(t *testing.T) { + primary := &mockHeimdallClient{} + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 5 * time.Second + fc.probeTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Start registry and let the immediate probe cycle complete so its + // FetchStatus hits don't interfere with assertions below. + fc.ensureHealthRegistry() + time.Sleep(50 * time.Millisecond) + + primaryBefore := primary.hits.Load() + secondaryBefore := secondary.hits.Load() + + for i := 0; i < 5; i++ { + _, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + } + + assert.Equal(t, primaryBefore+5, primary.hits.Load(), "all calls should go to primary") + assert.Equal(t, secondaryBefore, secondary.hits.Load(), "secondary should not be contacted for API calls") +} + +// Integration test using real HTTP servers to verify end-to-end behavior +func TestFailover_Integration_ServiceUnavailable(t *testing.T) { + primary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + t.Cleanup(primary.Close) + + secondary := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(secondary.Close) + + primaryClient := NewHeimdallClient(primary.URL, 5*time.Second) + secondaryClient := NewHeimdallClient(secondary.URL, 5*time.Second) + + fc, err := NewMultiHeimdallClient(primaryClient, secondaryClient) + require.NoError(t, err) + + fc.attemptTimeout = 2 * time.Second + defer fc.Close() + + ctx := WithRequestType(context.Background(), SpanRequest) + + // 503 should NOT trigger failover + _, err = fc.GetSpan(ctx, 1) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrServiceUnavailable)) +} + +func TestFailover_StateSyncEvents(t *testing.T) { + primary := &mockHeimdallClient{ + stateSyncEventsFn: func(_ context.Context, _ uint64, _ int64) ([]*clerk.EventRecordWithTime, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{ + stateSyncEventsFn: func(_ context.Context, fromID uint64, to int64) ([]*clerk.EventRecordWithTime, error) { + return []*clerk.EventRecordWithTime{{EventRecord: clerk.EventRecord{ID: fromID}}}, nil + }, + } + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + events, err := fc.StateSyncEvents(context.Background(), 42, 100) + require.NoError(t, err) + require.Len(t, events, 1) + assert.Equal(t, uint64(42), events[0].ID) +} + +func TestFailover_GetLatestSpan(t *testing.T) { + primary := &mockHeimdallClient{ + getLatestSpanFn: func(_ context.Context) (*types.Span, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{ + getLatestSpanFn: func(_ context.Context) (*types.Span, error) { + return &types.Span{Id: 77}, nil + }, + } + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + span, err := fc.GetLatestSpan(context.Background()) + require.NoError(t, err) + assert.Equal(t, uint64(77), span.Id) +} + +func TestFailover_FetchCheckpoint(t *testing.T) { + primary := &mockHeimdallClient{ + fetchCheckpointFn: func(_ context.Context, _ int64) (*checkpoint.Checkpoint, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + cp, err := fc.FetchCheckpoint(context.Background(), 5) + require.NoError(t, err) + require.NotNil(t, cp) +} + +func TestFailover_FetchCheckpointCount(t *testing.T) { + primary := &mockHeimdallClient{ + fetchCheckpointCntFn: func(_ context.Context) (int64, error) { + return 0, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + count, err := fc.FetchCheckpointCount(context.Background()) + require.NoError(t, err) + assert.Equal(t, int64(10), count) +} + +func TestFailover_FetchMilestone(t *testing.T) { + primary := &mockHeimdallClient{ + fetchMilestoneFn: func(_ context.Context) (*milestone.Milestone, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + ms, err := fc.FetchMilestone(context.Background()) + require.NoError(t, err) + require.NotNil(t, ms) +} + +func TestFailover_FetchMilestoneCount(t *testing.T) { + primary := &mockHeimdallClient{ + fetchMilestoneCntFn: func(_ context.Context) (int64, error) { + return 0, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + count, err := fc.FetchMilestoneCount(context.Background()) + require.NoError(t, err) + assert.Equal(t, int64(5), count) +} + +func TestFailover_FetchStatus(t *testing.T) { + primary := &mockHeimdallClient{ + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + status, err := fc.FetchStatus(context.Background()) + require.NoError(t, err) + require.NotNil(t, status) +} + +func TestFailover_SwitchOnPrimarySubContextError(t *testing.T) { + tests := []struct { + name string + primaryFn func(ctx context.Context, _ uint64) (*types.Span, error) + }{ + { + name: "DeadlineExceeded", + primaryFn: func(ctx context.Context, _ uint64) (*types.Span, error) { + <-ctx.Done() + return nil, ctx.Err() + }, + }, + { + name: "Canceled", + primaryFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, context.Canceled + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + primary := &mockHeimdallClient{getSpanFn: tt.primaryFn} + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + assert.GreaterOrEqual(t, primary.hits.Load(), int32(1), "primary should have been tried") + assert.GreaterOrEqual(t, secondary.hits.Load(), int32(1), "should failover on sub-context error") + }) + } +} + +func TestIsFailoverError(t *testing.T) { + ctx := context.Background() + + // Transport errors should trigger failover + netErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + assert.True(t, isFailoverError(netErr, ctx), "net.Error should trigger failover") + + // ErrNoResponse should trigger failover + assert.True(t, isFailoverError(ErrNoResponse, ctx), "ErrNoResponse should trigger failover") + + // 5xx HTTP errors should trigger failover; the server is unhealthy + assert.True(t, isFailoverError(&HTTPStatusError{StatusCode: 500}, ctx), "5xx should trigger failover") + assert.True(t, isFailoverError(fmt.Errorf("wrapped: %w", &HTTPStatusError{StatusCode: 502}), ctx), "wrapped 5xx should trigger failover") + + // 4xx HTTP errors should NOT trigger failover; a logical error will be the same on every node + assert.False(t, isFailoverError(&HTTPStatusError{StatusCode: 400}, ctx), "4xx should not trigger failover") + assert.False(t, isFailoverError(&HTTPStatusError{StatusCode: 404}, ctx), "4xx should not trigger failover") + + // DeadlineExceeded with live caller ctx should trigger failover + assert.True(t, isFailoverError(context.DeadlineExceeded, ctx), "DeadlineExceeded should trigger failover when caller ctx is alive") + + // Canceled with live caller ctx should trigger failover (sub-context was canceled, not the caller) + assert.True(t, isFailoverError(context.Canceled, ctx), "Canceled should trigger failover when caller ctx is alive") + + // ErrShutdownDetected should NOT trigger failover + assert.False(t, isFailoverError(ErrShutdownDetected, ctx), "ErrShutdownDetected should not trigger failover") + + // ErrServiceUnavailable should NOT trigger failover + assert.False(t, isFailoverError(ErrServiceUnavailable, ctx), "ErrServiceUnavailable should not trigger failover") + + // Caller context cancelled should NOT trigger failover + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() + assert.False(t, isFailoverError(context.DeadlineExceeded, cancelledCtx), "should not failover when caller ctx is done") + + // nil error should not trigger failover + assert.False(t, isFailoverError(nil, ctx), "nil error should not trigger failover") +} + +func TestFailover_ThreeClients_CascadeToTertiary(t *testing.T) { + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + tertiary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary, tertiary) + defer fc.Close() + + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + + assert.GreaterOrEqual(t, primary.hits.Load(), int32(1), "primary should have been tried") + assert.GreaterOrEqual(t, secondary.hits.Load(), int32(1), "secondary should have been tried") + assert.GreaterOrEqual(t, tertiary.hits.Load(), int32(1), "tertiary should have been called") +} + +func TestFailover_AllClientsFail(t *testing.T) { + connErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + tertiary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + + fc := newInstantMulti(primary, secondary, tertiary) + defer fc.Close() + + _, err := fc.GetSpan(context.Background(), 1) + require.Error(t, err) +} + +func TestFailover_ThreeClients_ProbeBackToPrimary(t *testing.T) { + primaryDown := atomic.Bool{} + primaryDown.Store(true) + + primary := newToggleMock(&primaryDown) + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, testConnErr + }, + } + tertiary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary, tertiary) + defer fc.Close() + + // Trigger cascade to tertiary + _, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Bring primary back + primaryDown.Store(false) + + // Wait for health registry to promote back to primary + require.Eventually(t, func() bool { + return fc.registry.Active() == 0 + }, 2*time.Second, 20*time.Millisecond, "health registry should promote back to primary") + + // Verify we're back on primary + tertiaryBefore := tertiary.hits.Load() + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + assert.Equal(t, tertiaryBefore, tertiary.hits.Load(), "should be back on primary now") +} + +// Active client returns non-failover error: should return directly, no cascade. +func TestFailover_ActiveNonFailoverError(t *testing.T) { + primary := &mockHeimdallClient{} + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, ErrShutdownDetected + }, + } + tertiary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary, tertiary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Force onto secondary + fc.registry.SetActive(1) + + _, err = fc.GetSpan(context.Background(), 1) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrShutdownDetected)) + assert.Equal(t, int32(0), tertiary.hits.Load(), "should not cascade to tertiary on non-failover error") +} + +// Active client returns failover error: cascade should try by priority. +func TestFailover_ActiveFailoverError_CascadesToNext(t *testing.T) { + // Primary also fails so cascade doesn't land there. + primary := newConnRefusedMock() + secondary := newConnRefusedMock() + tertiary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary, tertiary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.probeTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour // prevent background probes from promoting + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Force onto secondary + fc.registry.SetActive(1) + + span, getErr := fc.GetSpan(context.Background(), 1) + require.NoError(t, getErr) + require.NotNil(t, span) + assert.GreaterOrEqual(t, tertiary.hits.Load(), int32(1), "should cascade to tertiary") + + assert.Equal(t, 2, fc.registry.Active(), "active should switch to tertiary") +} + +func TestFailover_ClosesAllClients(t *testing.T) { + var closed [3]atomic.Bool + + clients := make([]Endpoint, 3) + for i := range clients { + idx := i + clients[i] = &mockHeimdallClient{closeFn: func() { closed[idx].Store(true) }} + } + + fc, err := NewMultiHeimdallClient(clients...) + require.NoError(t, err) + + fc.Close() + + for i := range closed { + assert.True(t, closed[i].Load(), "client %d should be closed", i) + } +} + +func TestFailover_HealthCheckPromotesHighestPriority(t *testing.T) { + primaryDown := atomic.Bool{} + primaryDown.Store(true) + + secondaryDown := atomic.Bool{} + secondaryDown.Store(true) + + primary := newProbeToggleMock(&primaryDown) + secondary := newProbeToggleMock(&secondaryDown) + tertiary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary, tertiary) + defer fc.Close() + + // Trigger cascade to tertiary + _, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Bring secondary back first + secondaryDown.Store(false) + + require.Eventually(t, func() bool { + return fc.registry.Active() == 1 + }, 2*time.Second, 20*time.Millisecond, "should promote to secondary") + + // Now bring primary back + primaryDown.Store(false) + + require.Eventually(t, func() bool { + return fc.registry.Active() == 0 + }, 2*time.Second, 20*time.Millisecond, "should promote to primary") +} + +func TestFailover_HealthRegistryRespectsClose(t *testing.T) { + primary := &mockHeimdallClient{ + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 50 * time.Millisecond + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + + // Close should stop the health registry goroutine + fc.Close() + + // No goroutine should be running after close — verify by checking + // that probe counts don't increase after close. + probesBefore := failoverProbeAttempts.Snapshot().Count() + time.Sleep(200 * time.Millisecond) + probesAfter := failoverProbeAttempts.Snapshot().Count() + + assert.Equal(t, probesBefore, probesAfter, "no probes should run after Close") +} + +// --- New health registry tests --- + +func TestRegistry_ConsecutiveThreshold(t *testing.T) { + probeCount := atomic.Int32{} + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, testConnErr + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + probeCount.Add(1) + return &ctypes.SyncInfo{}, nil + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 50 * time.Millisecond + fc.registry.ConsecutiveThreshold = 3 // need 3 consecutive successes + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Trigger failover + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + assert.Equal(t, 1, fc.registry.Active(), "should be on secondary") + + // Wait for enough probes to pass the threshold + require.Eventually(t, func() bool { + return probeCount.Load() >= 3 + }, 2*time.Second, 20*time.Millisecond, "should probe primary at least 3 times") + + // Should eventually promote after threshold met + require.Eventually(t, func() bool { + return fc.registry.Active() == 0 + }, 2*time.Second, 20*time.Millisecond, "should promote after consecutive threshold met") +} + +func TestRegistry_PromotionCooldown(t *testing.T) { + primaryDown := atomic.Bool{} + primaryDown.Store(true) + + primary := newProbeToggleMock(&primaryDown) + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 50 * time.Millisecond + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 500 * time.Millisecond // 500ms cooldown + defer fc.Close() + + // Trigger failover + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Bring primary back + primaryDown.Store(false) + + // Wait for at least one probe to succeed — primary should be healthy but not promoted yet + time.Sleep(150 * time.Millisecond) + assert.Equal(t, 1, fc.registry.Active(), "should not promote before cooldown") + + // Wait for cooldown to pass and promotion to happen + require.Eventually(t, func() bool { + return fc.registry.Active() == 0 + }, 3*time.Second, 20*time.Millisecond, "should promote after cooldown passes") +} + +func TestRegistry_FlappingPrevention(t *testing.T) { + callCount := atomic.Int32{} + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return nil, testConnErr + }, + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + n := callCount.Add(1) + // Alternate: success, fail, success, fail... + if n%2 == 0 { + return nil, testConnErr + } + return &ctypes.SyncInfo{}, nil + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 50 * time.Millisecond + fc.registry.ConsecutiveThreshold = 3 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Trigger failover + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Wait for several probe cycles + time.Sleep(500 * time.Millisecond) + + // Primary should never reach healthy because alternating success/fail + // never reaches 3 consecutive successes. + assert.Equal(t, 1, fc.registry.Active(), "should stay on secondary — flapping primary never reaches threshold") +} + +func TestRegistry_InformedCascade_SkipsUnhealthy(t *testing.T) { + primary := newConnRefusedMock() + secondary := newConnRefusedMock() + tertiary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary, tertiary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Mark secondary as unhealthy in the registry + fc.registry.SetHealth(1, EndpointHealth{Healthy: false}) + + // Trigger failover from primary + secondaryHitsBefore := secondary.hits.Load() + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + + // Secondary should not have been tried for the GetSpan call since it's unhealthy, + // but it may be tried in the last-resort pass. The key thing is that tertiary succeeds. + assert.Equal(t, 2, fc.registry.Active(), "should end up on tertiary") + + _ = secondaryHitsBefore +} + +func TestRegistry_InformedCascade_TriesByPriority(t *testing.T) { + connErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + + // Track call order + var callOrder []int + var orderMu sync.Mutex + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + orderMu.Lock() + callOrder = append(callOrder, 0) + orderMu.Unlock() + return &types.Span{Id: 1}, nil + }, + } + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + orderMu.Lock() + callOrder = append(callOrder, 1) + orderMu.Unlock() + return nil, connErr + }, + } + tertiary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + orderMu.Lock() + callOrder = append(callOrder, 2) + orderMu.Unlock() + return nil, connErr + }, + } + + fc, err := NewMultiHeimdallClient(primary, secondary, tertiary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Force active to index 1 (secondary); primary (index 0) is healthy + fc.registry.SetActive(1) + fc.registry.SetHealth(0, EndpointHealth{Healthy: true, HealthySince: time.Now().Add(-1 * time.Hour)}) + fc.registry.SetHealth(1, EndpointHealth{Healthy: true}) + fc.registry.SetHealth(2, EndpointHealth{Healthy: true}) + + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + + // Cascade should try primary (index 0) before tertiary (index 2) + assert.Equal(t, 0, fc.registry.Active(), "should cascade to primary (highest priority)") +} + +func TestRegistry_ProactiveSwitchOnActiveUnhealthy(t *testing.T) { + primaryDown := atomic.Bool{} + primaryDown.Store(false) + + primary := &mockHeimdallClient{ + fetchStatusFn: func(_ context.Context) (*ctypes.SyncInfo, error) { + if primaryDown.Load() { + return nil, &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + } + return &ctypes.SyncInfo{}, nil + }, + } + secondary := &mockHeimdallClient{} + + fc := newInstantMulti(primary, secondary) + defer fc.Close() + + // Start the health registry (normally started on first API call). + fc.ensureHealthRegistry() + + // Verify we start on primary + assert.Equal(t, 0, fc.registry.Active(), "should start on primary") + + // Now make primary go down — the health registry should detect and switch + primaryDown.Store(true) + + require.Eventually(t, func() bool { + return fc.registry.Active() == 1 + }, 2*time.Second, 20*time.Millisecond, "health registry should proactively switch to secondary") +} + +func TestRegistry_CascadeFallsBackToUnhealthy(t *testing.T) { + connErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + // Secondary is marked unhealthy but actually works + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.probeTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Start registry and let the immediate probe complete before setting up + // the test state, otherwise the probe can mark secondary healthy. + fc.ensureHealthRegistry() + time.Sleep(50 * time.Millisecond) + + // Mark secondary as unhealthy + fc.registry.SetHealth(1, EndpointHealth{Healthy: false}) + + // Primary fails, cascade should fall back to unhealthy secondary as last resort + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + + assert.Equal(t, 1, fc.registry.Active(), "should fall back to unhealthy secondary as last resort") +} + +func TestRegistry_MarkUnhealthyOnRealFailure(t *testing.T) { + connErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + // Primary starts as healthy + snap := fc.registry.HealthSnapshot() + assert.True(t, snap[0].Healthy, "primary should start healthy") + + // Trigger a real request that fails on primary + _, err = fc.GetSpan(context.Background(), 1) + require.NoError(t, err) // succeeds via secondary + + // Primary should now be marked unhealthy + snap = fc.registry.HealthSnapshot() + assert.False(t, snap[0].Healthy, "primary should be marked unhealthy after real failure") + assert.Equal(t, 0, snap[0].ConsecutiveSuccess, "consecutive success should be reset") +} + +func TestFailover_ProbeUsesProbeTimeout(t *testing.T) { + // Verify that probes use the short probeTimeout, not the long attemptTimeout. + // A probe against a hanging endpoint should fail within probeTimeout, not + // wait for attemptTimeout. + primary := &mockHeimdallClient{ + fetchStatusFn: func(ctx context.Context) (*ctypes.SyncInfo, error) { + // Hang until context expires. + <-ctx.Done() + return nil, ctx.Err() + }, + } + secondary := &mockHeimdallClient{} + + fc, err := NewMultiHeimdallClient(primary, secondary) + require.NoError(t, err) + + fc.attemptTimeout = 10 * time.Second // long — should NOT be used for probes + fc.probeTimeout = 200 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 0 + defer fc.Close() + + start := time.Now() + fc.registry.Start() + + // Wait for the immediate probe cycle to complete. + require.Eventually(t, func() bool { + snap := fc.registry.HealthSnapshot() + return !snap[0].Healthy || snap[0].LastErr != nil + }, 2*time.Second, 20*time.Millisecond, "probe should complete") + + elapsed := time.Since(start) + assert.Less(t, elapsed, 2*time.Second, "probe should complete within probeTimeout, not attemptTimeout") +} + +func TestRegistry_InformedCascade_RespectsCooldown(t *testing.T) { + connErr := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")} + + // Primary (index 0): healthy but NOT cooled (recently became healthy) + // Secondary (index 1): fails (active) + // Tertiary (index 2): healthy AND cooled + + primary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return &types.Span{Id: 1}, nil + }, + } + secondary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { return nil, connErr }, + } + tertiary := &mockHeimdallClient{ + getSpanFn: func(_ context.Context, _ uint64) (*types.Span, error) { + return &types.Span{Id: 3}, nil + }, + } + + fc, err := NewMultiHeimdallClient(primary, secondary, tertiary) + require.NoError(t, err) + + fc.attemptTimeout = 100 * time.Millisecond + fc.registry.HealthCheckInterval = 1 * time.Hour + fc.registry.ConsecutiveThreshold = 1 + fc.registry.PromotionCooldown = 1 * time.Hour // long cooldown + defer fc.Close() + + // Set up health states + fc.registry.SetActive(1) + fc.registry.SetHealth(0, EndpointHealth{Healthy: true, HealthySince: time.Now()}) // NOT cooled + fc.registry.SetHealth(1, EndpointHealth{Healthy: true}) // active, will fail + fc.registry.SetHealth(2, EndpointHealth{Healthy: true, HealthySince: time.Now().Add(-2 * time.Hour)}) // cooled + + span, err := fc.GetSpan(context.Background(), 1) + require.NoError(t, err) + require.NotNil(t, span) + + // Should prefer tertiary (cooled) over primary (uncooled) + assert.Equal(t, 2, fc.registry.Active(), "should prefer cooled tertiary over uncooled primary") +} diff --git a/consensus/bor/heimdall/failover_metrics.go b/consensus/bor/heimdall/failover_metrics.go new file mode 100644 index 0000000000..482fb6fa29 --- /dev/null +++ b/consensus/bor/heimdall/failover_metrics.go @@ -0,0 +1,21 @@ +package heimdall + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + // HTTP/gRPC failover metrics (used within this package) + failoverSwitchCounter = metrics.NewRegisteredCounter("client/failover/switches", nil) + failoverActiveGauge = metrics.NewRegisteredGauge("client/failover/active", nil) + failoverProbeAttempts = metrics.NewRegisteredCounter("client/failover/probe/attempts", nil) + failoverProbeSuccesses = metrics.NewRegisteredCounter("client/failover/probe/successes", nil) + failoverHealthyEndpoints = metrics.NewRegisteredGauge("client/failover/healthy_endpoints", nil) + failoverProactiveSwitches = metrics.NewRegisteredCounter("client/failover/proactive_switches", nil) + + // WS failover metrics (exported for use by heimdallws package) + FailoverWSSwitchCounter = metrics.NewRegisteredCounter("client/failover/ws/switches", nil) + FailoverWSActiveGauge = metrics.NewRegisteredGauge("client/failover/ws/active", nil) + FailoverWSProbeAttempts = metrics.NewRegisteredCounter("client/failover/ws/probe/attempts", nil) + FailoverWSProbeSuccesses = metrics.NewRegisteredCounter("client/failover/ws/probe/successes", nil) + FailoverWSHealthyEndpoints = metrics.NewRegisteredGauge("client/failover/ws/healthy_endpoints", nil) + FailoverWSProactiveSwitches = metrics.NewRegisteredCounter("client/failover/ws/proactive_switches", nil) +) diff --git a/consensus/bor/heimdall/health_registry.go b/consensus/bor/heimdall/health_registry.go new file mode 100644 index 0000000000..de61698584 --- /dev/null +++ b/consensus/bor/heimdall/health_registry.go @@ -0,0 +1,387 @@ +package heimdall + +import ( + "sync" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" +) + +// EndpointHealth tracks the health state of a single endpoint. +type EndpointHealth struct { + Healthy bool + ConsecutiveSuccess int + HealthySince time.Time // when consecutive threshold was reached + LastErr error +} + +// RegistryMetrics holds the metrics counters/gauges that a HealthRegistry reports to. +// Nil pointers are safe — the registry checks before calling. +type RegistryMetrics struct { + ProbeAttempts *metrics.Counter + ProbeSuccesses *metrics.Counter + ProactiveSwitches *metrics.Counter + ActiveGauge *metrics.Gauge + HealthyEndpoints *metrics.Gauge +} + +// HealthRegistry is a shared health state machine for N endpoints. +// It runs a background goroutine that probes all endpoints, promotes +// higher-priority endpoints when healthy+cooled, and proactively switches +// away from unhealthy active endpoints. +type HealthRegistry struct { + mu sync.Mutex + health []EndpointHealth + active int + n int + + // Exported config fields — set after construction, before Start(). + HealthCheckInterval time.Duration + ConsecutiveThreshold int + PromotionCooldown time.Duration + + probeFunc func(i int) error + onSwitch func(from, to int) // called outside mu to avoid lock-ordering issues + + metrics RegistryMetrics + + quit chan struct{} + done chan struct{} // closed when run() exits + closeOnce sync.Once + startOnce sync.Once +} + +// NewHealthRegistry creates a registry for n endpoints. +// probeFunc is called for each endpoint index to test reachability. +// onSwitch (optional) is called outside the registry lock when the active +// endpoint changes due to promotion, proactive switch, or SetActive. +func NewHealthRegistry(n int, probeFunc func(int) error, onSwitch func(from, to int), m RegistryMetrics) *HealthRegistry { + health := make([]EndpointHealth, n) + // Primary starts as healthy; others start unhealthy. + health[0] = EndpointHealth{Healthy: true} + + return &HealthRegistry{ + health: health, + n: n, + HealthCheckInterval: defaultHealthCheckInterval, + ConsecutiveThreshold: defaultConsecutiveThreshold, + PromotionCooldown: defaultPromotionCooldown, + probeFunc: probeFunc, + onSwitch: onSwitch, + metrics: m, + quit: make(chan struct{}), + done: make(chan struct{}), + } +} + +// Active returns the index of the currently active endpoint. +func (r *HealthRegistry) Active() int { + r.mu.Lock() + defer r.mu.Unlock() + + return r.active +} + +// SetActive sets the active endpoint index, updates the gauge, and calls onSwitch +// if the active endpoint changed. The caller must NOT hold r.mu. +func (r *HealthRegistry) SetActive(i int) { + r.mu.Lock() + prev := r.active + r.active = i + + if r.metrics.ActiveGauge != nil { + r.metrics.ActiveGauge.Update(int64(i)) + } + r.mu.Unlock() + + // Call onSwitch outside r.mu to avoid lock-ordering deadlock. + // The WS client's onWSSwitch callback acquires c.mu, so calling it + // under r.mu would create a registry.mu → c.mu path that conflicts + // with the c.mu → registry.mu path in tryUntilSubscribeMilestoneEvents. + if prev != i && r.onSwitch != nil { + r.onSwitch(prev, i) + } +} + +// MarkUnhealthy resets the health state of endpoint i to unhealthy. +func (r *HealthRegistry) MarkUnhealthy(i int, err error) { + r.mu.Lock() + defer r.mu.Unlock() + + r.health[i].ConsecutiveSuccess = 0 + r.health[i].Healthy = false + r.health[i].LastErr = err +} + +// MarkSuccess increments the consecutive success count for endpoint i and +// transitions it to healthy if the threshold is met. +func (r *HealthRegistry) MarkSuccess(i int) { + r.mu.Lock() + defer r.mu.Unlock() + + r.health[i].ConsecutiveSuccess++ + r.health[i].LastErr = nil + + if r.health[i].ConsecutiveSuccess >= r.ConsecutiveThreshold && !r.health[i].Healthy { + r.health[i].Healthy = true + r.health[i].HealthySince = time.Now() + } +} + +// HealthSnapshot returns a copy of all endpoint health states. +func (r *HealthRegistry) HealthSnapshot() []EndpointHealth { + r.mu.Lock() + defer r.mu.Unlock() + + snap := make([]EndpointHealth, r.n) + copy(snap, r.health) + + return snap +} + +// SetHealth directly overrides the health state of endpoint i. +// Intended for tests that need to manipulate state. +func (r *HealthRegistry) SetHealth(i int, h EndpointHealth) { + r.mu.Lock() + defer r.mu.Unlock() + + r.health[i] = h +} + +// Start lazily starts the background health-check goroutine via startOnce. +func (r *HealthRegistry) Start() { + r.startOnce.Do(func() { + go r.run() + }) +} + +// Stop closes the quit channel and waits for the background goroutine to exit. +func (r *HealthRegistry) Stop() { + // If Start() was never called, close done so the wait below doesn't block. + r.startOnce.Do(func() { + close(r.done) + }) + + r.closeOnce.Do(func() { + close(r.quit) + }) + + <-r.done +} + +// run is the background goroutine: probe → promote → proactive switch. +func (r *HealthRegistry) run() { + defer close(r.done) + + // Run an immediate probe cycle so a down primary is detected within + // seconds of boot rather than waiting for the first ticker fire. + r.probeAll() + r.maybePromote() + r.maybeProactiveSwitch() + + ticker := time.NewTicker(r.HealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-r.quit: + return + case <-ticker.C: + } + + r.probeAll() + r.maybePromote() + r.maybeProactiveSwitch() + } +} + +// probeAll probes every endpoint concurrently and updates health state. +// Each goroutine applies its own result immediately so that a request +// arriving mid-cycle (via callWithFailover → HealthSnapshot) sees fresh +// data for already-completed probes rather than stale data for all of them. +func (r *HealthRegistry) probeAll() { + // Check for shutdown before launching probes. + select { + case <-r.quit: + return + default: + } + + var wg sync.WaitGroup + wg.Add(r.n) + + for i := 0; i < r.n; i++ { + if r.metrics.ProbeAttempts != nil { + r.metrics.ProbeAttempts.Inc(1) + } + + go func(idx int) { + defer wg.Done() + + err := r.probeFunc(idx) + + // Apply this probe's result immediately. + r.mu.Lock() + if err == nil { + r.health[idx].ConsecutiveSuccess++ + r.health[idx].LastErr = nil + + if r.health[idx].ConsecutiveSuccess >= r.ConsecutiveThreshold && !r.health[idx].Healthy { + r.health[idx].Healthy = true + r.health[idx].HealthySince = time.Now() + } + + if r.metrics.ProbeSuccesses != nil { + r.metrics.ProbeSuccesses.Inc(1) + } + } else { + r.health[idx].ConsecutiveSuccess = 0 + r.health[idx].Healthy = false + r.health[idx].LastErr = err + } + r.mu.Unlock() + }(i) + } + + wg.Wait() + + // Update gauge after all probes complete — needs to scan all results. + select { + case <-r.quit: + return + default: + } + + if r.metrics.HealthyEndpoints != nil { + r.mu.Lock() + healthyCount := int64(0) + for i := 0; i < r.n; i++ { + if r.health[i].Healthy { + healthyCount++ + } + } + r.mu.Unlock() + + r.metrics.HealthyEndpoints.Update(healthyCount) + } +} + +// maybePromote checks if a higher-priority endpoint (index < active) is healthy +// and has passed cooldown. If yes, promotes to the highest-priority qualified endpoint. +func (r *HealthRegistry) maybePromote() { + var prev, next int + doSwitch := false + + r.mu.Lock() + + if r.active != 0 { + for i := 0; i < r.active; i++ { + if r.health[i].Healthy && time.Since(r.health[i].HealthySince) >= r.PromotionCooldown { + prev = r.active + next = i + r.active = i + doSwitch = true + + if r.metrics.ActiveGauge != nil { + r.metrics.ActiveGauge.Update(int64(i)) + } + + if r.metrics.ProactiveSwitches != nil { + r.metrics.ProactiveSwitches.Inc(1) + } + + break + } + } + } + + r.mu.Unlock() + + if doSwitch { + log.Info("Health registry: promoted to higher-priority endpoint", + "index", next, "previous", prev) + + if r.onSwitch != nil { + r.onSwitch(prev, next) + } + } +} + +// maybeProactiveSwitch detects if the active endpoint is unhealthy and switches +// to the highest-priority healthy endpoint. +func (r *HealthRegistry) maybeProactiveSwitch() { + var prev, next int + doSwitch := false + var logMsg string + + r.mu.Lock() + + if r.health[r.active].Healthy { + r.mu.Unlock() + return + } + + // Active is unhealthy. Find the best alternative. + // Pass 1: healthy + cooled. + for i := 0; i < r.n; i++ { + if i == r.active { + continue + } + + if r.health[i].Healthy && time.Since(r.health[i].HealthySince) >= r.PromotionCooldown { + prev = r.active + next = i + r.active = i + doSwitch = true + logMsg = "Health registry: proactive switch (active unhealthy, cooled target)" + + if r.metrics.ActiveGauge != nil { + r.metrics.ActiveGauge.Update(int64(i)) + } + + if r.metrics.ProactiveSwitches != nil { + r.metrics.ProactiveSwitches.Inc(1) + } + + break + } + } + + // Pass 2: healthy but NOT cooled (emergency). + if !doSwitch { + for i := 0; i < r.n; i++ { + if i == r.active { + continue + } + + if r.health[i].Healthy { + prev = r.active + next = i + r.active = i + doSwitch = true + logMsg = "Health registry: proactive switch (active unhealthy, uncooled target)" + + if r.metrics.ActiveGauge != nil { + r.metrics.ActiveGauge.Update(int64(i)) + } + + if r.metrics.ProactiveSwitches != nil { + r.metrics.ProactiveSwitches.Inc(1) + } + + break + } + } + } + + r.mu.Unlock() + + if doSwitch { + log.Warn(logMsg, "from", prev, "to", next) + + if r.onSwitch != nil { + r.onSwitch(prev, next) + } + } +} diff --git a/consensus/bor/heimdall/health_registry_test.go b/consensus/bor/heimdall/health_registry_test.go new file mode 100644 index 0000000000..9761dd05a1 --- /dev/null +++ b/consensus/bor/heimdall/health_registry_test.go @@ -0,0 +1,338 @@ +package heimdall + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHealthRegistry_Constructor_PrimaryHealthy(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + + snap := r.HealthSnapshot() + assert.Len(t, snap, 3) + assert.True(t, snap[0].Healthy, "primary should start healthy") + assert.False(t, snap[1].Healthy, "secondary should start unhealthy") + assert.False(t, snap[2].Healthy, "tertiary should start unhealthy") + assert.Equal(t, 0, r.Active()) +} + +func TestHealthRegistry_MarkUnhealthy(t *testing.T) { + r := NewHealthRegistry(2, func(i int) error { return nil }, nil, RegistryMetrics{}) + + r.MarkUnhealthy(0, errors.New("down")) + + snap := r.HealthSnapshot() + assert.False(t, snap[0].Healthy) + assert.Equal(t, 0, snap[0].ConsecutiveSuccess) + assert.EqualError(t, snap[0].LastErr, "down") +} + +func TestHealthRegistry_MarkSuccess_Transitions(t *testing.T) { + r := NewHealthRegistry(2, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.ConsecutiveThreshold = 3 + + // Endpoint 1 starts unhealthy. + snap := r.HealthSnapshot() + assert.False(t, snap[1].Healthy) + + // Two successes: still unhealthy. + r.MarkSuccess(1) + r.MarkSuccess(1) + snap = r.HealthSnapshot() + assert.False(t, snap[1].Healthy) + assert.Equal(t, 2, snap[1].ConsecutiveSuccess) + + // Third success: transitions to healthy. + r.MarkSuccess(1) + snap = r.HealthSnapshot() + assert.True(t, snap[1].Healthy) + assert.Equal(t, 3, snap[1].ConsecutiveSuccess) + assert.False(t, snap[1].HealthySince.IsZero()) +} + +func TestHealthRegistry_MarkSuccess_ResetByFailure(t *testing.T) { + r := NewHealthRegistry(2, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.ConsecutiveThreshold = 3 + + r.MarkSuccess(1) + r.MarkSuccess(1) + r.MarkUnhealthy(1, errors.New("fail")) + + snap := r.HealthSnapshot() + assert.False(t, snap[1].Healthy) + assert.Equal(t, 0, snap[1].ConsecutiveSuccess) + + // Need 3 more successes after reset. + r.MarkSuccess(1) + snap = r.HealthSnapshot() + assert.False(t, snap[1].Healthy) +} + +func TestHealthRegistry_SetActive_CallsOnSwitch(t *testing.T) { + var switchFrom, switchTo int + called := false + + r := NewHealthRegistry(2, func(i int) error { return nil }, func(from, to int) { + called = true + switchFrom = from + switchTo = to + }, RegistryMetrics{}) + + r.SetActive(1) + assert.True(t, called) + assert.Equal(t, 0, switchFrom) + assert.Equal(t, 1, switchTo) + assert.Equal(t, 1, r.Active()) +} + +func TestHealthRegistry_SetActive_NoCallOnSameIndex(t *testing.T) { + called := false + r := NewHealthRegistry(2, func(i int) error { return nil }, func(from, to int) { + called = true + }, RegistryMetrics{}) + + r.SetActive(0) // same as current + assert.False(t, called, "onSwitch should not be called when active doesn't change") +} + +func TestHealthRegistry_SetHealth(t *testing.T) { + r := NewHealthRegistry(2, func(i int) error { return nil }, nil, RegistryMetrics{}) + + h := EndpointHealth{ + Healthy: true, + ConsecutiveSuccess: 5, + HealthySince: time.Now().Add(-1 * time.Hour), + } + r.SetHealth(1, h) + + snap := r.HealthSnapshot() + assert.True(t, snap[1].Healthy) + assert.Equal(t, 5, snap[1].ConsecutiveSuccess) +} + +func TestHealthRegistry_ProbeAll(t *testing.T) { + probeResults := []error{nil, errors.New("fail"), nil} + probeCount := atomic.Int32{} + + r := NewHealthRegistry(3, func(i int) error { + probeCount.Add(1) + return probeResults[i] + }, nil, RegistryMetrics{}) + r.ConsecutiveThreshold = 1 + + r.probeAll() + + assert.Equal(t, int32(3), probeCount.Load()) + + snap := r.HealthSnapshot() + // Index 0 was already healthy, stays healthy. + assert.True(t, snap[0].Healthy) + // Index 1 failed: unhealthy. + assert.False(t, snap[1].Healthy) + assert.EqualError(t, snap[1].LastErr, "fail") + // Index 2 succeeded once with threshold=1: becomes healthy. + assert.True(t, snap[2].Healthy) +} + +func TestHealthRegistry_MaybePromote(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.PromotionCooldown = 0 + r.ConsecutiveThreshold = 1 + + // Set active to 2, mark index 0 as unhealthy, make index 1 healthy+cooled. + r.SetActive(2) + r.SetHealth(0, EndpointHealth{Healthy: false}) + r.SetHealth(1, EndpointHealth{ + Healthy: true, + HealthySince: time.Now().Add(-1 * time.Hour), + }) + + r.maybePromote() + + assert.Equal(t, 1, r.Active(), "should promote to index 1") +} + +func TestHealthRegistry_MaybePromote_RespectsOrder(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.PromotionCooldown = 0 + + // Active at 2, both 0 and 1 healthy — should promote to 0 (highest priority). + r.SetActive(2) + r.SetHealth(0, EndpointHealth{Healthy: true, HealthySince: time.Now().Add(-1 * time.Hour)}) + r.SetHealth(1, EndpointHealth{Healthy: true, HealthySince: time.Now().Add(-1 * time.Hour)}) + + r.maybePromote() + + assert.Equal(t, 0, r.Active(), "should promote to index 0 (highest priority)") +} + +func TestHealthRegistry_MaybePromote_RespectsCooldown(t *testing.T) { + r := NewHealthRegistry(2, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.PromotionCooldown = 1 * time.Hour + + // Active at 1, index 0 healthy but recently (not cooled). + r.SetActive(1) + r.SetHealth(0, EndpointHealth{Healthy: true, HealthySince: time.Now()}) + + r.maybePromote() + + assert.Equal(t, 1, r.Active(), "should not promote — cooldown not met") +} + +func TestHealthRegistry_MaybeProactiveSwitch_CooledFirst(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.PromotionCooldown = 0 + + // Active at 0, mark it unhealthy. Index 2 is healthy+cooled. + r.SetHealth(0, EndpointHealth{Healthy: false}) + r.SetHealth(2, EndpointHealth{Healthy: true, HealthySince: time.Now().Add(-1 * time.Hour)}) + + r.maybeProactiveSwitch() + + assert.Equal(t, 2, r.Active(), "should switch to cooled healthy endpoint") +} + +func TestHealthRegistry_MaybeProactiveSwitch_UncooledFallback(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + r.PromotionCooldown = 1 * time.Hour + + // Active at 0, mark it unhealthy. Index 1 is healthy but NOT cooled. + r.SetHealth(0, EndpointHealth{Healthy: false}) + r.SetHealth(1, EndpointHealth{Healthy: true, HealthySince: time.Now()}) // not cooled + + r.maybeProactiveSwitch() + + assert.Equal(t, 1, r.Active(), "should fall back to uncooled healthy endpoint") +} + +func TestHealthRegistry_MaybeProactiveSwitch_NoHealthy(t *testing.T) { + r := NewHealthRegistry(3, func(i int) error { return nil }, nil, RegistryMetrics{}) + + // All unhealthy. + r.SetHealth(0, EndpointHealth{Healthy: false}) + r.SetHealth(1, EndpointHealth{Healthy: false}) + r.SetHealth(2, EndpointHealth{Healthy: false}) + + r.maybeProactiveSwitch() + + assert.Equal(t, 0, r.Active(), "should stay on 0 when no alternatives are healthy") +} + +func TestHealthRegistry_ImmediateProbeOnStart(t *testing.T) { + probeCount := atomic.Int32{} + + r := NewHealthRegistry(2, func(i int) error { + probeCount.Add(1) + return nil + }, nil, RegistryMetrics{}) + r.HealthCheckInterval = 10 * time.Second // long interval — should NOT gate first probe + + r.Start() + defer r.Stop() + + // The first probe cycle should fire immediately, not after HealthCheckInterval. + require.Eventually(t, func() bool { + return probeCount.Load() >= 2 // 2 endpoints probed + }, 2*time.Second, 10*time.Millisecond, "first probe cycle should run immediately on Start") +} + +func TestHealthRegistry_ProbeAll_IncrementalUpdate(t *testing.T) { + // Verify that a fast probe's result is visible before a slow probe completes. + slowStarted := make(chan struct{}) + slowRelease := make(chan struct{}) + + r := NewHealthRegistry(2, func(i int) error { + if i == 0 { + // Fast probe: returns immediately. + return nil + } + // Slow probe: blocks until released. + close(slowStarted) + <-slowRelease + return nil + }, nil, RegistryMetrics{}) + r.ConsecutiveThreshold = 1 + + // Run probeAll in a goroutine since the slow probe blocks. + done := make(chan struct{}) + go func() { + r.probeAll() + close(done) + }() + + // Wait for the slow probe to start (meaning the fast probe has already completed). + select { + case <-slowStarted: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for slow probe to start") + } + + // The fast probe (index 0) should already be applied even though the slow + // probe (index 1) is still in flight. + snap := r.HealthSnapshot() + assert.True(t, snap[0].Healthy, "fast probe result should be visible before slow probe completes") + + // Release the slow probe and wait for probeAll to finish. + close(slowRelease) + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for probeAll to finish") + } + + snap = r.HealthSnapshot() + assert.True(t, snap[1].Healthy, "slow probe result should be applied after release") +} + +func TestHealthRegistry_Stop_HaltsGoroutine(t *testing.T) { + probeCount := atomic.Int32{} + + r := NewHealthRegistry(2, func(i int) error { + probeCount.Add(1) + return nil + }, nil, RegistryMetrics{}) + r.HealthCheckInterval = 50 * time.Millisecond + + r.Start() + time.Sleep(150 * time.Millisecond) + r.Stop() + + countAfterStop := probeCount.Load() + time.Sleep(200 * time.Millisecond) + + assert.Equal(t, countAfterStop, probeCount.Load(), "no probes should run after Stop") +} + +func TestHealthRegistry_Run_Integration(t *testing.T) { + probeResults := []error{errors.New("down"), nil} + var results atomic.Value + results.Store(probeResults) + + r := NewHealthRegistry(2, func(i int) error { + return results.Load().([]error)[i] + }, nil, RegistryMetrics{}) + r.HealthCheckInterval = 50 * time.Millisecond + r.ConsecutiveThreshold = 1 + r.PromotionCooldown = 0 + + r.Start() + defer r.Stop() + + // Primary is down, secondary is healthy. Should proactively switch. + require.Eventually(t, func() bool { + return r.Active() == 1 + }, 2*time.Second, 20*time.Millisecond, "should switch to healthy secondary") + + // Bring primary back. + results.Store([]error{nil, nil}) + + // Should promote back to primary. + require.Eventually(t, func() bool { + return r.Active() == 0 + }, 2*time.Second, 20*time.Millisecond, "should promote back to primary") +} diff --git a/consensus/bor/heimdallws/client.go b/consensus/bor/heimdallws/client.go index 2428f289b5..f1d0cb2ec5 100644 --- a/consensus/bor/heimdallws/client.go +++ b/consensus/bor/heimdallws/client.go @@ -3,6 +3,7 @@ package heimdallws import ( "context" "encoding/json" + "errors" "strconv" "sync" "time" @@ -10,27 +11,129 @@ import ( "github.com/gorilla/websocket" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall" "github.com/ethereum/go-ethereum/consensus/bor/heimdall/milestone" "github.com/ethereum/go-ethereum/log" ) -// HeimdallWSClient represents a websocket client with auto-reconnection. +var ( + ErrNoURLs = errors.New("at least one WS URL required") + ErrNoNonEmptyURLs = errors.New("at least one non-empty WS URL required") +) + +const ( + // defaultReconnectDelay is the backoff between reconnection attempts. + defaultReconnectDelay = 10 * time.Second + + // defaultWSProbeTimeout bounds each individual WS probe dial so a + // firewalled host can't block the health-check goroutine forever. + defaultWSProbeTimeout = 10 * time.Second +) + +// HeimdallWSClient represents a websocket client with auto-reconnection and failover support. type HeimdallWSClient struct { - conn *websocket.Conn - url string // store the URL for reconnection - events chan *milestone.Milestone - done chan struct{} - mu sync.Mutex + conn *websocket.Conn + connEpoch uint64 // incremented on each connection change; detects proactive switches + urls []string // primary at [0], secondary at [1] (if configured) + registry *heimdall.HealthRegistry + events chan *milestone.Milestone + done chan struct{} + mu sync.Mutex + + // Configurable parameters (defaults set in constructor, overridable for testing) + reconnectDelay time.Duration + probeTimeout time.Duration } -// NewHeimdallWSClient creates a new WS client for Heimdall. -func NewHeimdallWSClient(url string) (*HeimdallWSClient, error) { - return &HeimdallWSClient{ - conn: nil, - url: url, - events: make(chan *milestone.Milestone), - done: make(chan struct{}), - }, nil +// NewHeimdallWSClient creates a new WS client for Heimdall with optional failover. +// The first URL is primary; additional URLs are failover candidates in priority order. +func NewHeimdallWSClient(urls ...string) (*HeimdallWSClient, error) { + if len(urls) == 0 { + return nil, ErrNoURLs + } + + var filtered []string + for _, u := range urls { + if u != "" { + filtered = append(filtered, u) + } + } + + if len(filtered) == 0 { + return nil, ErrNoNonEmptyURLs + } + + c := &HeimdallWSClient{ + conn: nil, + urls: filtered, + events: make(chan *milestone.Milestone), + done: make(chan struct{}), + reconnectDelay: defaultReconnectDelay, + probeTimeout: defaultWSProbeTimeout, + } + + c.registry = heimdall.NewHealthRegistry( + len(filtered), + c.probeWSEndpoint, + c.onWSSwitch, + heimdall.RegistryMetrics{ + ProbeAttempts: heimdall.FailoverWSProbeAttempts, + ProbeSuccesses: heimdall.FailoverWSProbeSuccesses, + ProactiveSwitches: heimdall.FailoverWSProactiveSwitches, + ActiveGauge: heimdall.FailoverWSActiveGauge, + HealthyEndpoints: heimdall.FailoverWSHealthyEndpoints, + }, + ) + + return c, nil +} + +// probeWSEndpoint dials a WS endpoint and immediately closes the connection. +func (c *HeimdallWSClient) probeWSEndpoint(i int) error { + c.mu.Lock() + url := c.urls[i] + c.mu.Unlock() + + dialer := websocket.Dialer{ + HandshakeTimeout: c.probeTimeout, + } + + ctx, cancel := context.WithTimeout(context.Background(), c.probeTimeout) + defer cancel() + + testConn, _, err := dialer.DialContext(ctx, url, nil) + if err != nil { + return err + } + + testConn.Close() + + return nil +} + +// onWSSwitch is called by the registry (under registry lock) when the active +// endpoint changes. It bumps the connection epoch, closes the current connection, +// and nils it out. The epoch change lets readMessages distinguish a proactive +// switch from a real network error, avoiding misleading logs and double-closes. +func (c *HeimdallWSClient) onWSSwitch(from, to int) { + c.mu.Lock() + defer c.mu.Unlock() + + c.connEpoch++ + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +// connEpochChanged reports whether the connection epoch has advanced past the +// given snapshot, indicating that a proactive switch (or reconnection) occurred. +func (c *HeimdallWSClient) connEpochChanged(epoch uint64) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.connEpoch != epoch } // SubscribeMilestoneEvents sends the subscription request and starts processing incoming messages. @@ -40,19 +143,35 @@ func (c *HeimdallWSClient) SubscribeMilestoneEvents(ctx context.Context) <-chan // Start the goroutine to read messages. go c.readMessages(ctx) + // Start the health registry if there are multiple URLs. + if len(c.urls) > 1 { + c.registry.Start() + } + return c.events } -// retry until subscribe +// tryUntilSubscribeMilestoneEvents retries connecting and subscribing until success, +// consulting the health registry to pick the best URL. func (c *HeimdallWSClient) tryUntilSubscribeMilestoneEvents(ctx context.Context) { firstTime := true + for { if !firstTime { - time.Sleep(10 * time.Second) + select { + case <-ctx.Done(): + log.Info("Context cancelled during reconnection") + return + case <-c.done: + log.Info("Client unsubscribed during reconnection") + return + case <-time.After(c.reconnectDelay): + } } + firstTime = false - // Check for context cancellation. + // Check for context cancellation or unsubscribe. select { case <-ctx.Done(): log.Info("Context cancelled during reconnection") @@ -63,16 +182,64 @@ func (c *HeimdallWSClient) tryUntilSubscribeMilestoneEvents(ctx context.Context) default: } - conn, _, err := websocket.DefaultDialer.Dial(c.url, nil) + active := c.registry.Active() + url := c.urls[active] + + conn, _, err := websocket.DefaultDialer.Dial(url, nil) if err != nil { - log.Error("failed to dial websocket on heimdall ws subscription", "err", err) + log.Error("failed to dial websocket on heimdall ws subscription", "url", url, "err", err) + + // Mark endpoint unhealthy in the registry. + c.registry.MarkUnhealthy(active, err) + + // Find the best healthy alternative. + snap := c.registry.HealthSnapshot() + switched := false + + for i := 0; i < len(c.urls); i++ { + if i == active { + continue + } + + if snap[i].Healthy { + c.registry.SetActive(i) + switched = true + + heimdall.FailoverWSSwitchCounter.Inc(1) + + log.Warn("WS URL failed, switching to healthy endpoint", + "from", c.urls[active], "to", c.urls[i]) + + break + } + } + + // If no healthy alternative, try next in round-robin fashion. + if !switched && len(c.urls) > 1 { + next := (active + 1) % len(c.urls) + if next != active { + c.registry.SetActive(next) + + heimdall.FailoverWSSwitchCounter.Inc(1) + + log.Warn("WS URL failed, switching to next endpoint", + "from", c.urls[active], "to", c.urls[next]) + } + } + continue } + + // Close previous connection if any, then set the new one. c.mu.Lock() + if c.conn != nil { + c.conn.Close() + } c.conn = conn - c.mu.Unlock() + c.connEpoch++ - // Build the subscription request. + // Build the subscription request and send it under lock to avoid + // racing with readMessages on c.conn. req := subscriptionRequest{ JSONRPC: "2.0", Method: "subscribe", @@ -80,11 +247,20 @@ func (c *HeimdallWSClient) tryUntilSubscribeMilestoneEvents(ctx context.Context) } req.Params.Query = "tm.event='NewBlock' AND milestone.number>0" - if err := c.conn.WriteJSON(req); err != nil { - log.Error("failed to send subscription request on heimdall ws subscription", "err", err) + err = c.conn.WriteJSON(req) + c.mu.Unlock() + + // Mark outside c.mu to prevent lock-ordering deadlock with + // registry.mu → c.mu (onWSSwitch called from health-check goroutine). + c.registry.MarkSuccess(active) + + if err != nil { + log.Error("failed to send subscription request on heimdall ws subscription", "url", url, "err", err) continue } - log.Info("Successfully connected on heimdall ws subscription") + + log.Info("successfully connected on heimdall ws subscription", "url", url) + return } } @@ -92,6 +268,7 @@ func (c *HeimdallWSClient) tryUntilSubscribeMilestoneEvents(ctx context.Context) // readMessages continuously reads messages from the websocket, handling reconnections if necessary. func (c *HeimdallWSClient) readMessages(ctx context.Context) { defer close(c.events) + for { // Check if the context or unsubscribe signal is set. select { @@ -103,15 +280,38 @@ func (c *HeimdallWSClient) readMessages(ctx context.Context) { // continue to process messages } - if err := c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil { + // Grab local ref and epoch under lock to detect proactive switches. + c.mu.Lock() + conn := c.conn + epoch := c.connEpoch + c.mu.Unlock() + + if conn == nil { + c.tryUntilSubscribeMilestoneEvents(ctx) + continue + } + + if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil { + if c.connEpochChanged(epoch) { + // Proactive switch closed the connection; loop back to pick up the new endpoint. + log.Info("reconnecting due to endpoint switch on heimdall ws subscription") + continue + } + log.Error("failed to set read deadline on heimdall ws subscription", "err", err) c.tryUntilSubscribeMilestoneEvents(ctx) continue } - _, message, err := c.conn.ReadMessage() + _, message, err := conn.ReadMessage() if err != nil { + if c.connEpochChanged(epoch) { + // Proactive switch closed the connection; loop back to pick up the new endpoint. + log.Info("reconnecting due to endpoint switch on heimdall ws subscription") + continue + } + log.Error("connection lost; will attempt to reconnect on heimdall ws subscription", "error", err) c.tryUntilSubscribeMilestoneEvents(ctx) @@ -177,13 +377,19 @@ func (c *HeimdallWSClient) readMessages(ctx context.Context) { // Unsubscribe signals the reader goroutine to stop. func (c *HeimdallWSClient) Unsubscribe(ctx context.Context) error { c.mu.Lock() - defer c.mu.Unlock() select { case <-c.done: // Already unsubscribed. default: close(c.done) } + c.mu.Unlock() + + // Stop the registry outside c.mu to avoid deadlock with probeWSEndpoint, + // which acquires c.mu to read the URL while running under the registry's + // run() goroutine. + c.registry.Stop() + return nil } @@ -191,5 +397,10 @@ func (c *HeimdallWSClient) Unsubscribe(ctx context.Context) error { func (c *HeimdallWSClient) Close() error { c.mu.Lock() defer c.mu.Unlock() + + if c.conn == nil { + return nil + } + return c.conn.Close() } diff --git a/consensus/bor/heimdallws/client_test.go b/consensus/bor/heimdallws/client_test.go new file mode 100644 index 0000000000..a5b2f4330f --- /dev/null +++ b/consensus/bor/heimdallws/client_test.go @@ -0,0 +1,558 @@ +package heimdallws + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// newTestWSServer creates a test WS server that accepts connections and sends a subscription ack. +// If reject is true, the server closes connections immediately. +func newTestWSServer(t *testing.T, reject bool) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if reject { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + + // Read the subscription request. + _, _, err = conn.ReadMessage() + if err != nil { + return + } + + // Send a simple ack (not a milestone, just keeps connection alive). + ack := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 0, + "result": map[string]interface{}{}, + } + + if err := conn.WriteJSON(ack); err != nil { + return + } + + // Keep the connection open until client disconnects. + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + })) +} + +// newTestWSServerWithMilestone creates a test WS server that sends a milestone event after connection. +func newTestWSServerWithMilestone(t *testing.T) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + + // Read the subscription request. + _, _, err = conn.ReadMessage() + if err != nil { + return + } + + // Send a milestone event. + resp := wsResponse{ + JSONRPC: "2.0", + ID: 0, + Result: wsResult{ + Query: "tm.event='NewBlock' AND milestone.number>0", + Data: wsData{ + Type: "tendermint/event/NewBlock", + Value: wsValue{ + FinalizeBlock: finalizeBlock{ + Events: []wsEvent{ + { + Type: "milestone", + Attributes: []attribute{ + {Key: "proposer", Value: "0x0000000000000000000000000000000000000001"}, + {Key: "hash", Value: "0x0000000000000000000000000000000000000000000000000000000000000002"}, + {Key: "start_block", Value: "100"}, + {Key: "end_block", Value: "200"}, + {Key: "bor_chain_id", Value: "137"}, + {Key: "milestone_id", Value: "test-1"}, + {Key: "timestamp", Value: "1000"}, + {Key: "total_difficulty", Value: "500"}, + }, + }, + }, + }, + }, + }, + }, + } + + data, _ := json.Marshal(resp) + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + return + } + + // Keep connection open. + for { + if _, _, err := conn.ReadMessage(); err != nil { + return + } + } + })) +} + +func wsURL(httpURL string) string { + return "ws" + strings.TrimPrefix(httpURL, "http") +} + +func TestWSClient_ConstructorSingleURL(t *testing.T) { + client, err := NewHeimdallWSClient("ws://localhost:1234") + require.NoError(t, err) + assert.Len(t, client.urls, 1) + assert.Equal(t, "ws://localhost:1234", client.urls[0]) + assert.Equal(t, 0, client.registry.Active()) + snap := client.registry.HealthSnapshot() + assert.Len(t, snap, 1) + assert.True(t, snap[0].Healthy, "primary should start healthy") +} + +func TestWSClient_ConstructorMultipleURLs(t *testing.T) { + client, err := NewHeimdallWSClient("ws://primary:1234", "ws://secondary:5678", "ws://tertiary:9999") + require.NoError(t, err) + assert.Len(t, client.urls, 3) + assert.Equal(t, "ws://primary:1234", client.urls[0]) + assert.Equal(t, "ws://secondary:5678", client.urls[1]) + assert.Equal(t, "ws://tertiary:9999", client.urls[2]) + assert.Equal(t, 0, client.registry.Active()) + snap := client.registry.HealthSnapshot() + assert.Len(t, snap, 3) + assert.True(t, snap[0].Healthy, "primary should start healthy") + assert.False(t, snap[1].Healthy, "secondary should start unhealthy") + assert.False(t, snap[2].Healthy, "tertiary should start unhealthy") +} + +func TestWSClient_ConstructorFiltersEmpty(t *testing.T) { + client, err := NewHeimdallWSClient("ws://primary:1234", "", "ws://tertiary:9999") + require.NoError(t, err) + assert.Len(t, client.urls, 2) + assert.Equal(t, "ws://primary:1234", client.urls[0]) + assert.Equal(t, "ws://tertiary:9999", client.urls[1]) +} + +func TestWSClient_ConstructorNoURLs(t *testing.T) { + _, err := NewHeimdallWSClient() + require.Error(t, err) +} + +func TestWSClient_ConstructorAllEmpty(t *testing.T) { + _, err := NewHeimdallWSClient("", "") + require.Error(t, err) +} + +func TestWSClient_SingleURL_ConnectsSuccessfully(t *testing.T) { + server := newTestWSServerWithMilestone(t) + defer server.Close() + + client, err := NewHeimdallWSClient(wsURL(server.URL)) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + select { + case m := <-events: + require.NotNil(t, m) + assert.Equal(t, uint64(100), m.StartBlock) + assert.Equal(t, uint64(200), m.EndBlock) + assert.Equal(t, "137", m.BorChainID) + assert.Equal(t, "test-1", m.MilestoneID) + case <-ctx.Done(): + t.Fatal("timed out waiting for milestone event") + } + + require.NoError(t, client.Unsubscribe(ctx)) +} + +func TestWSClient_DualURL_FailoverToSecondary(t *testing.T) { + // Primary always rejects. + primary := newTestWSServer(t, true) + defer primary.Close() + + // Secondary accepts and sends a milestone. + secondary := newTestWSServerWithMilestone(t) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + // Speed up test. + client.reconnectDelay = 100 * time.Millisecond + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + select { + case m := <-events: + require.NotNil(t, m) + assert.Equal(t, uint64(100), m.StartBlock) + assert.Equal(t, uint64(200), m.EndBlock) + // Verify we switched to secondary. + assert.Equal(t, 1, client.registry.Active()) + case <-ctx.Done(): + t.Fatal("timed out waiting for milestone event via failover") + } + + require.NoError(t, client.Unsubscribe(ctx)) +} + +func TestWSClient_ThreeURL_CascadeToTertiary(t *testing.T) { + // Primary and secondary always reject. + primary := newTestWSServer(t, true) + defer primary.Close() + + secondary := newTestWSServer(t, true) + defer secondary.Close() + + // Tertiary accepts and sends a milestone. + tertiary := newTestWSServerWithMilestone(t) + defer tertiary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL), wsURL(tertiary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + select { + case m := <-events: + require.NotNil(t, m) + assert.Equal(t, uint64(100), m.StartBlock) + // Verify we ended up on tertiary. + assert.Equal(t, 2, client.registry.Active()) + case <-ctx.Done(): + t.Fatal("timed out waiting for milestone event via cascade") + } + + require.NoError(t, client.Unsubscribe(ctx)) +} + +func TestWSClient_ContextCancellation(t *testing.T) { + // Both URLs reject — client should respect context cancellation. + primary := newTestWSServer(t, true) + defer primary.Close() + + secondary := newTestWSServer(t, true) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after a short delay. + go func() { + time.Sleep(300 * time.Millisecond) + cancel() + }() + + // tryUntilSubscribeMilestoneEvents should return without blocking forever. + client.tryUntilSubscribeMilestoneEvents(ctx) + + // Verify context was cancelled. + assert.Error(t, ctx.Err()) +} + +func TestWSClient_DualURL_ProbeBackToPrimary(t *testing.T) { + fix := setupWSFailover(t, 100*time.Millisecond, 1, 0) + defer fix.cleanup(t) + + // Wait for background health registry to promote back to primary. + require.Eventually(t, func() bool { + return fix.client.registry.Active() == 0 + }, 5*time.Second, 50*time.Millisecond, "health registry should promote back to primary") +} + +func TestWSClient_DualURL_NoWrapOnLastURLFails(t *testing.T) { + // Both URLs reject. The client should handle correctly when on last URL. + primary := newTestWSServer(t, true) + defer primary.Close() + + secondary := newTestWSServer(t, true) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 10 * time.Millisecond + client.registry.HealthCheckInterval = 1 * time.Hour // prevent health-check from interfering + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + // Pre-set to secondary as if a prior failover already happened. + client.registry.SetActive(1) + + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + client.tryUntilSubscribeMilestoneEvents(ctx) + + // Should have moved off secondary since it fails. + active := client.registry.Active() + + // May have wrapped to primary (index 0) since secondary fails. + _ = active // either index is acceptable; the important thing is it didn't hang. +} + +func TestWSClient_DualURL_PrimaryRecovery(t *testing.T) { + // Start with primary down, then bring it up. + + // Primary starts rejecting. + primaryReject := newTestWSServer(t, true) + + // Secondary accepts with milestone. + secondary := newTestWSServerWithMilestone(t) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primaryReject.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + // Should failover to secondary. + select { + case m := <-events: + require.NotNil(t, m) + assert.Equal(t, 1, client.registry.Active()) + assert.Equal(t, uint64(100), m.StartBlock) + case <-ctx.Done(): + t.Fatal("timed out waiting for failover") + } + + // Close the rejecting primary. + primaryReject.Close() + + require.NoError(t, client.Unsubscribe(ctx)) +} + +func TestWSClient_HealthRegistryRespectsUnsubscribe(t *testing.T) { + // Verify that the health registry goroutine stops when done channel is closed. + primary := newTestWSServer(t, true) + defer primary.Close() + + secondary := newTestWSServerWithMilestone(t) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.HealthCheckInterval = 50 * time.Millisecond + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + // Wait for failover to secondary. + select { + case m := <-events: + require.NotNil(t, m) + case <-ctx.Done(): + t.Fatal("timed out waiting for failover") + } + + // Unsubscribe should stop the health registry goroutine. + require.NoError(t, client.Unsubscribe(ctx)) + + // Give a moment for the goroutine to stop and verify no panics. + time.Sleep(200 * time.Millisecond) +} + +// wsFailoverFixture holds the shared state for WS failover tests that start with +// a rejecting primary, failover to a milestone-serving secondary, then swap in a +// good primary to test promotion behavior. +type wsFailoverFixture struct { + client *HeimdallWSClient + ctx context.Context + cancel context.CancelFunc +} + +// setupWSFailover creates a rejecting primary and accepting secondary, subscribes +// to milestone events, waits for failover to secondary, then replaces the primary +// with an accepting server. The caller can then assert promotion behavior. +func setupWSFailover(t *testing.T, healthInterval time.Duration, threshold int, cooldown time.Duration) *wsFailoverFixture { + t.Helper() + + primaryReject := newTestWSServer(t, true) + t.Cleanup(primaryReject.Close) + + secondary := newTestWSServerWithMilestone(t) + t.Cleanup(secondary.Close) + + client, err := NewHeimdallWSClient(wsURL(primaryReject.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.HealthCheckInterval = healthInterval + client.registry.ConsecutiveThreshold = threshold + client.registry.PromotionCooldown = cooldown + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + events := client.SubscribeMilestoneEvents(ctx) + + // Wait for failover to secondary. + select { + case m := <-events: + require.NotNil(t, m) + case <-ctx.Done(): + t.Fatal("timed out waiting for failover") + } + + // Replace rejecting primary with accepting one. + primaryReject.Close() + + primaryGood := newTestWSServer(t, false) + t.Cleanup(primaryGood.Close) + + client.mu.Lock() + client.urls[0] = wsURL(primaryGood.URL) + client.mu.Unlock() + + return &wsFailoverFixture{client: client, ctx: ctx, cancel: cancel} +} + +func (f *wsFailoverFixture) cleanup(t *testing.T) { + t.Helper() + + defer f.cancel() + require.NoError(t, f.client.Unsubscribe(f.ctx)) +} + +// --- New health registry tests --- + +func TestWSClient_Registry_ConsecutiveThreshold(t *testing.T) { + fix := setupWSFailover(t, 50*time.Millisecond, 3, 0) + defer fix.cleanup(t) + + // Should eventually promote after 3 consecutive successes. + require.Eventually(t, func() bool { + return fix.client.registry.Active() == 0 + }, 5*time.Second, 50*time.Millisecond, "should promote after consecutive threshold met") +} + +func TestWSClient_Registry_PromotionCooldown(t *testing.T) { + fix := setupWSFailover(t, 50*time.Millisecond, 1, 500*time.Millisecond) + defer fix.cleanup(t) + + // Should not promote immediately (cooldown not met). + time.Sleep(150 * time.Millisecond) + assert.Equal(t, 1, fix.client.registry.Active(), "should not promote before cooldown") + + // Wait for cooldown to pass. + require.Eventually(t, func() bool { + return fix.client.registry.Active() == 0 + }, 3*time.Second, 50*time.Millisecond, "should promote after cooldown passes") +} + +func TestWSClient_ProactiveSwitchSetsConnNil(t *testing.T) { + // Verify that onWSSwitch nils out the connection and bumps the epoch, + // so readMessages detects the switch via epoch change rather than + // seeing a stale non-nil closed conn. + primary := newTestWSServerWithMilestone(t) + defer primary.Close() + + secondary := newTestWSServerWithMilestone(t) + defer secondary.Close() + + client, err := NewHeimdallWSClient(wsURL(primary.URL), wsURL(secondary.URL)) + require.NoError(t, err) + + client.reconnectDelay = 100 * time.Millisecond + client.registry.HealthCheckInterval = 1 * time.Hour // manual control + client.registry.ConsecutiveThreshold = 1 + client.registry.PromotionCooldown = 0 + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := client.SubscribeMilestoneEvents(ctx) + + // Receive milestone from primary. + select { + case m := <-events: + require.NotNil(t, m) + assert.Equal(t, 0, client.registry.Active()) + case <-ctx.Done(): + t.Fatal("timed out waiting for milestone from primary") + } + + // Capture epoch before switch. + client.mu.Lock() + epochBefore := client.connEpoch + client.mu.Unlock() + + // Simulate a proactive switch by calling onWSSwitch directly. + client.onWSSwitch(0, 1) + + // Verify conn is nil and epoch advanced. + client.mu.Lock() + assert.Nil(t, client.conn, "onWSSwitch should nil out the connection") + assert.Greater(t, client.connEpoch, epochBefore, "onWSSwitch should bump epoch") + client.mu.Unlock() + + // readMessages should detect the nil conn and reconnect. + // Set active to secondary so reconnection goes there. + client.registry.SetActive(1) + + require.NoError(t, client.Unsubscribe(ctx)) +} diff --git a/eth/ethconfig/config.go b/eth/ethconfig/config.go index 27b4a25fed..64d7361f5e 100644 --- a/eth/ethconfig/config.go +++ b/eth/ethconfig/config.go @@ -18,7 +18,9 @@ package ethconfig import ( + "fmt" "math/big" + "strings" "time" "github.com/ethereum/go-ethereum/common" @@ -45,6 +47,25 @@ import ( "github.com/ethereum/go-ethereum/params" ) +// parseURLs splits a comma-separated URL string into a trimmed, non-empty slice. +func parseURLs(s string) []string { + if s == "" { + return nil + } + + parts := strings.Split(s, ",") + + var out []string + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + + return out +} + // FullNodeGPO contains default gasprice oracle settings for full node. var FullNodeGPO = gasprice.Config{ Blocks: 20, @@ -210,7 +231,7 @@ type Config struct { // position in eth_getLogs filter criteria (0 = no cap) RPCLogQueryLimit int - // URL to connect to Heimdall node + // URL to connect to Heimdall node (comma-separated for failover: "url1,url2,url3") HeimdallURL string // timeout in heimdall requests @@ -219,10 +240,10 @@ type Config struct { // No heimdall service WithoutHeimdall bool - // Address to connect to Heimdall gRPC server + // Address to connect to Heimdall gRPC server (comma-separated for failover: "addr1,addr2") HeimdallgRPCAddress string - // Address to connect to Heimdall WS subscription server + // Address to connect to Heimdall WS subscription server (comma-separated for failover: "addr1,addr2") HeimdallWSAddress string // Run heimdall service as a child process @@ -334,33 +355,70 @@ func CreateConsensusEngine(chainConfig *params.ChainConfig, ethConfig *Config, d // TODO: Running heimdall from bor is not tested yet. // heimdallClient = heimdallapp.NewHeimdallAppClient() panic("Running heimdall from bor is not implemented yet. Please use heimdall gRPC or HTTP client instead.") - } else if ethConfig.HeimdallgRPCAddress != "" { - grpcClient, err := heimdallgrpc.NewHeimdallGRPCClient( - ethConfig.HeimdallgRPCAddress, - ethConfig.HeimdallURL, - ethConfig.HeimdallTimeout, - ) - if err != nil { - log.Error("Failed to initialize Heimdall gRPC client; falling back to HTTP Heimdall client", - "heimdall_grpc", ethConfig.HeimdallgRPCAddress, - "heimdall_http", ethConfig.HeimdallURL, - "err", err, - ) + } else { + httpURLs := parseURLs(ethConfig.HeimdallURL) + grpcAddrs := parseURLs(ethConfig.HeimdallgRPCAddress) + + // Build one client per endpoint. + // gRPC takes priority where configured; falls back to HTTP. + var heimdallClients []heimdall.Endpoint + + n := max(len(httpURLs), len(grpcAddrs)) + for i := 0; i < n; i++ { + if i < len(grpcAddrs) && grpcAddrs[i] != "" { + var httpURL string + if len(httpURLs) > 0 { + httpURL = httpURLs[min(i, len(httpURLs)-1)] + } + + grpcClient, err := heimdallgrpc.NewHeimdallGRPCClient(grpcAddrs[i], httpURL, ethConfig.HeimdallTimeout) + if err != nil { + log.Error("Failed to initialize Heimdall gRPC client; falling back to HTTP", + "index", i, "grpc", grpcAddrs[i], "err", err) + + if i < len(httpURLs) { + heimdallClients = append(heimdallClients, heimdall.NewHeimdallClient(httpURLs[i], ethConfig.HeimdallTimeout)) + } + + continue + } + + heimdallClients = append(heimdallClients, grpcClient) + } else if i < len(httpURLs) { + heimdallClients = append(heimdallClients, heimdall.NewHeimdallClient(httpURLs[i], ethConfig.HeimdallTimeout)) + } + } + + if len(heimdallClients) == 0 { heimdallClient = heimdall.NewHeimdallClient(ethConfig.HeimdallURL, ethConfig.HeimdallTimeout) + } else if len(heimdallClients) == 1 { + heimdallClient = heimdallClients[0] } else { - heimdallClient = grpcClient + multiClient, err := heimdall.NewMultiHeimdallClient(heimdallClients...) + if err != nil { + return nil, fmt.Errorf("failed to create heimdall failover client: %w", err) + } + + heimdallClient = multiClient + log.Info("Heimdall failover enabled with multiple endpoints", "endpoints", len(heimdallClients)) } - } else { - heimdallClient = heimdall.NewHeimdallClient(ethConfig.HeimdallURL, ethConfig.HeimdallTimeout) } + // WS client + wsAddrs := parseURLs(ethConfig.HeimdallWSAddress) + var heimdallWSClient bor.IHeimdallWSClient var err error - if ethConfig.HeimdallWSAddress != "" { - heimdallWSClient, err = heimdallws.NewHeimdallWSClient(ethConfig.HeimdallWSAddress) + + if len(wsAddrs) > 0 { + heimdallWSClient, err = heimdallws.NewHeimdallWSClient(wsAddrs...) if err != nil { return nil, err } + + if len(wsAddrs) > 1 { + log.Info("Heimdall WS failover enabled with multiple endpoints", "endpoints", len(wsAddrs)) + } } return bor.New(chainConfig, db, blockchainAPI, spanner, heimdallClient, heimdallWSClient, genesisContractsClient, false, ethConfig.Miner.BlockTime), nil diff --git a/eth/ethconfig/config_test.go b/eth/ethconfig/config_test.go index b85431d12d..302a570834 100644 --- a/eth/ethconfig/config_test.go +++ b/eth/ethconfig/config_test.go @@ -10,10 +10,13 @@ import ( ctypes "github.com/cometbft/cometbft/rpc/core/types" "github.com/ethereum/go-ethereum/consensus/bor" "github.com/ethereum/go-ethereum/consensus/bor/clerk" + "github.com/ethereum/go-ethereum/consensus/bor/heimdall" "github.com/ethereum/go-ethereum/consensus/bor/heimdall/checkpoint" "github.com/ethereum/go-ethereum/consensus/bor/heimdall/milestone" + "github.com/ethereum/go-ethereum/consensus/bor/heimdallws" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/params" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -88,6 +91,41 @@ func TestCreateConsensusEngine_OverrideHeimdallClient(t *testing.T) { require.True(t, ok, "Expected Bor consensus engine") } +func TestCreateConsensusEngine_CommaSeparatedHeimdallURL(t *testing.T) { + t.Parallel() + ethConfig := &Config{ + HeimdallURL: "http://primary:1317,http://secondary:1317", + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + _, ok = borEngine.HeimdallClient.(*heimdall.MultiHeimdallClient) + require.True(t, ok, "Expected HeimdallClient to be wrapped in MultiHeimdallClient") +} + +func TestCreateConsensusEngine_SingleHeimdallURL(t *testing.T) { + t.Parallel() + ethConfig := &Config{ + HeimdallURL: "http://primary:1317", + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + // Single URL should NOT produce a MultiHeimdallClient + _, ok = borEngine.HeimdallClient.(*heimdall.MultiHeimdallClient) + require.False(t, ok, "Expected no MultiHeimdallClient for single URL") +} + func TestCreateConsensusEngine_WithoutHeimdall(t *testing.T) { t.Parallel() ethConfig := &Config{WithoutHeimdall: true} @@ -99,3 +137,152 @@ func TestCreateConsensusEngine_WithoutHeimdall(t *testing.T) { _, ok := engine.(*bor.Bor) require.True(t, ok, "Expected Bor consensus engine") } + +func TestCreateConsensusEngine_CommaSeparatedGRPC(t *testing.T) { + t.Parallel() + ethConfig := &Config{ + HeimdallURL: "http://primary:1317,http://secondary:1317", + HeimdallgRPCAddress: "localhost:50051,localhost:50052", + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + _, ok = borEngine.HeimdallClient.(*heimdall.MultiHeimdallClient) + require.True(t, ok, "Expected MultiHeimdallClient with multiple gRPC endpoints") +} + +func TestCreateConsensusEngine_GRPCInitFailsFallsBackToHTTP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + heimdallURL string + grpcAddress string + expectFailover bool + }{ + { + // gRPC uses unsupported scheme → NewHeimdallGRPCClient fails. + // Fallback appends HTTP client for httpURLs[0]; httpURLs[1] also + // gets an HTTP client via the else-if branch → 2 clients → failover. + name: "with HTTP URL available", + heimdallURL: "http://a:1317,http://b:1317", + grpcAddress: "ftp://invalid:50051", + expectFailover: true, + }, + { + // gRPC[0] succeeds (localhost is allowed), gRPC[1] fails (bad scheme). + // i=1 >= len(httpURLs)=1 so no HTTP fallback is added → only 1 client. + name: "without HTTP URL at that index", + heimdallURL: "http://a:1317", + grpcAddress: "localhost:50051,ftp://invalid:50052", + expectFailover: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ethConfig := &Config{ + HeimdallURL: tt.heimdallURL, + HeimdallgRPCAddress: tt.grpcAddress, + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + _, ok = borEngine.HeimdallClient.(*heimdall.MultiHeimdallClient) + require.Equal(t, tt.expectFailover, ok) + }) + } +} + +func TestCreateConsensusEngine_WSAddress(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + }{ + {"comma-separated", "ws://localhost:26657,ws://secondary:26657"}, + {"primary only", "ws://localhost:26657"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ethConfig := &Config{ + OverrideHeimdallClient: &mockHeimdallClient{}, + HeimdallWSAddress: tt.addr, + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + require.NotNil(t, borEngine.HeimdallWSClient, "Expected non-nil HeimdallWSClient") + + _, ok = borEngine.HeimdallWSClient.(*heimdallws.HeimdallWSClient) + require.True(t, ok, "Expected HeimdallWSClient type") + }) + } +} + +func TestCreateConsensusEngine_NoWSAddress(t *testing.T) { + t.Parallel() + + ethConfig := &Config{ + OverrideHeimdallClient: &mockHeimdallClient{}, + // No HeimdallWSAddress set + } + + engine, err := CreateConsensusEngine(newTestBorChainConfig(), ethConfig, rawdb.NewMemoryDatabase(), nil) + require.NoError(t, err) + defer engine.Close() + + borEngine, ok := engine.(*bor.Bor) + require.True(t, ok, "Expected Bor consensus engine") + + require.Nil(t, borEngine.HeimdallWSClient, "Expected nil HeimdallWSClient when no WS address configured") +} + +func TestParseURLs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected []string + }{ + {"empty string", "", nil}, + {"single URL", "http://localhost:1317", []string{"http://localhost:1317"}}, + {"two URLs", "http://a:1317,http://b:1317", []string{"http://a:1317", "http://b:1317"}}, + {"three URLs", "http://a:1317,http://b:1317,http://c:1317", []string{"http://a:1317", "http://b:1317", "http://c:1317"}}, + {"whitespace trimmed", " http://a:1317 , http://b:1317 ", []string{"http://a:1317", "http://b:1317"}}, + {"trailing comma", "http://a:1317,", []string{"http://a:1317"}}, + {"leading comma", ",http://a:1317", []string{"http://a:1317"}}, + {"empty entries filtered", "http://a:1317,,http://b:1317", []string{"http://a:1317", "http://b:1317"}}, + {"only commas", ",,,", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := parseURLs(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/cli/server/config.go b/internal/cli/server/config.go index 262cada9ac..a5f6eed5fc 100644 --- a/internal/cli/server/config.go +++ b/internal/cli/server/config.go @@ -312,7 +312,7 @@ type P2PDiscovery struct { } type HeimdallConfig struct { - // URL is the url of the heimdall server + // URL is the url of the heimdall server (comma-separated for failover: "url1,url2,url3") URL string `hcl:"url,optional" toml:"url,optional"` Timeout time.Duration `hcl:"timeout,optional" toml:"timeout,optional"` @@ -320,10 +320,10 @@ type HeimdallConfig struct { // Without is used to disable remote heimdall during testing Without bool `hcl:"bor.without,optional" toml:"bor.without,optional"` - // GRPCAddress is the address of the heimdall grpc server + // GRPCAddress is the address of the heimdall grpc server (comma-separated for failover: "addr1,addr2") GRPCAddress string `hcl:"grpc-address,optional" toml:"grpc-address,optional"` - // WSAddress is the address of the heimdall ws subscription server + // WSAddress is the address of the heimdall ws subscription server (comma-separated for failover: "addr1,addr2") WSAddress string `hcl:"ws-address,optional" toml:"ws-address,optional"` // RunHeimdall is used to run heimdall as a child process diff --git a/internal/cli/server/flags.go b/internal/cli/server/flags.go index b7ccb7bc5a..6d4891d67f 100644 --- a/internal/cli/server/flags.go +++ b/internal/cli/server/flags.go @@ -175,7 +175,7 @@ func (c *Command) Flags(config *Config) *flagset.Flagset { // heimdall f.StringFlag(&flagset.StringFlag{ Name: "bor.heimdall", - Usage: "URL of Heimdall service", + Usage: "URL of Heimdall service (comma-separated for failover: \"url1,url2\")", Value: &c.cliConfig.Heimdall.URL, Default: c.cliConfig.Heimdall.URL, }) @@ -199,13 +199,13 @@ func (c *Command) Flags(config *Config) *flagset.Flagset { }) f.StringFlag(&flagset.StringFlag{ Name: "bor.heimdallgRPC", - Usage: "Address of Heimdall gRPC service", + Usage: "Address of Heimdall gRPC service (comma-separated for failover: \"addr1,addr2\")", Value: &c.cliConfig.Heimdall.GRPCAddress, Default: c.cliConfig.Heimdall.GRPCAddress, }) f.StringFlag(&flagset.StringFlag{ Name: "bor.heimdallWS", - Usage: "Address of Heimdall ws subscription service", + Usage: "Address of Heimdall WS subscription service (comma-separated for failover: \"addr1,addr2\")", Value: &c.cliConfig.Heimdall.WSAddress, Default: c.cliConfig.Heimdall.WSAddress, }) diff --git a/internal/cli/server/testdata/default.toml b/internal/cli/server/testdata/default.toml index d3b00e5fcc..c3213e2633 100644 --- a/internal/cli/server/testdata/default.toml +++ b/internal/cli/server/testdata/default.toml @@ -52,6 +52,7 @@ devfakeauthor = false url = "http://localhost:1317" "bor.without" = false grpc-address = "" + ws-address = "" "bor.runheimdall" = false "bor.runheimdallargs" = "" "bor.useheimdallapp" = false