@@ -7,6 +7,8 @@ package shared
77import (
88 "context"
99 "errors"
10+ "fmt"
11+ "io"
1012 "net"
1113 "net/http"
1214 "net/url"
@@ -32,16 +34,19 @@ import (
3234// preserving path and query.
3335type FailoverRoundTripper struct {
3436 cfg *Configuration
37+ opts *FailoverOptions
3538 base http.RoundTripper
3639}
3740
38- // NewFailoverRoundTripper creates a new FailoverRoundTripper with the given configuration and base RoundTripper.
39- func NewFailoverRoundTripper(cfg *Configuration, base http.RoundTripper) http.RoundTripper {
41+ // NewFailoverRoundTripper creates a new FailoverRoundTripper.
42+ // If opts is nil, it will fall back to cfg.Failover.
43+ func NewFailoverRoundTripper(cfg *Configuration, opts *FailoverOptions, base http.RoundTripper) http.RoundTripper {
4044 if base == nil {
4145 base = http.DefaultTransport
4246 }
4347 return &FailoverRoundTripper{
4448 cfg: cfg,
49+ opts: opts,
4550 base: base,
4651 }
4752}
@@ -51,60 +56,86 @@ func (t *FailoverRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
5156 if req == nil {
5257 return nil, errors.New("nil request")
5358 }
54- if t == nil || t.cfg == nil {
55- if t != nil && t.base != nil {
56- return t.base.RoundTrip(req)
57- }
58- return http.DefaultTransport.RoundTrip(req)
59+ if t == nil {
60+ return nil, errors.New("nil FailoverRoundTripper")
61+ }
62+ if t.base == nil {
63+ // Be resilient if instantiated without constructor.
64+ t.base = http.DefaultTransport
65+ }
66+ if t.cfg == nil {
67+ // No config => behave like the base transport.
68+ return t.base.RoundTrip(req)
5969 }
6070
61- strategyName := strings.TrimSpace(strings.ToLower(string(t.cfg.FailoverStrategy)))
62- servers := len(t.cfg.Servers)
63- if strategyName == "" || strategyName == string(FailoverNone) || servers <= 1 {
71+ fo := t.opts
72+ if fo == nil {
73+ fo = t.cfg.Failover
74+ }
75+ if fo == nil {
6476 return t.base.RoundTrip(req)
6577 }
6678
67- if strategyName != strings.ToLower(string(FailoverRoundRobin)) {
79+ servers := t.cfg.Servers
80+ order := serverOrderFor(fo.Strategy, len(servers))
81+ if order == nil {
82+ // Unknown or disabled strategy => pass through.
6883 return t.base.RoundTrip(req)
6984 }
7085
7186 // Check if method is allowed for failover retries.
72- if !isRetryableMethod(t.cfg , req.Method) {
87+ if !isRetryableMethod(fo , req.Method) {
7388 return t.base.RoundTrip(req)
7489 }
7590
7691 var lastErr error
77- for i := range servers {
78- // Always start from the first server in the list.
79- serverURL := t.cfg.Servers[i].URL
92+ for attempt := range len(servers) {
93+ serverURL := servers[order(attempt)].URL
8094
8195 targetURL, err := url.Parse(serverURL)
8296 if err != nil {
83- lastErr = err
84- continue
97+ return nil, fmt.Errorf("invalid server URL at Servers[%d]=%q: %w", order(attempt), serverURL, err)
8598 }
8699
87100 attemptReq, err := cloneRequestForRetry(req)
88101 if err != nil {
89102 return nil, err
90103 }
91104
92- // Update both URL and the Host header field.
93105 attemptReq.URL.Scheme = targetURL.Scheme
94106 attemptReq.URL.Host = targetURL.Host
95107 attemptReq.Host = targetURL.Host
96108
109+ if SdkLogLevel.Satisfies(Debug) {
110+ SdkLogger.Printf("[Failover] attempt=%d method=%s url=%s", attempt+1, attemptReq.Method, attemptReq.URL.String())
111+ }
112+
97113 resp, err := t.base.RoundTrip(attemptReq)
98114 if err == nil {
115+ if shouldFailoverOnStatus(fo, resp.StatusCode) {
116+ if SdkLogLevel.Satisfies(Debug) {
117+ SdkLogger.Printf("[Failover] status=%d triggers failover to next server", resp.StatusCode)
118+ }
119+ // Drain/close body to allow connection reuse.
120+ if resp.Body != nil {
121+ _, _ = io.Copy(io.Discard, resp.Body)
122+ _ = resp.Body.Close()
123+ }
124+ lastErr = fmt.Errorf("failover status: %s", resp.Status)
125+ continue
126+ }
99127 return resp, nil
100128 }
101129
102130 lastErr = err
103- if !isNetworkErrorRT(attemptReq.Context(), err, t.cfg.RetryOnTimeout) {
131+ retryable := isNetworkErrorRT(attemptReq.Context(), err, fo.RetryOnTimeout)
132+ if !retryable {
104133 return nil, err
105134 }
135+ if SdkLogLevel.Satisfies(Debug) {
136+ SdkLogger.Printf("[Failover] network error: %v; trying next server", err)
137+ }
106138
107- // Ensure we don't spin too hot in case of immediate failures.
108139 tinyBackoff(attemptReq.Context())
109140 }
110141
@@ -126,18 +157,17 @@ func cloneRequestForRetry(req *http.Request) (*http.Request, error) {
126157 return clone, nil
127158}
128159
129- func isRetryableMethod(cfg *Configuration , method string) bool {
160+ func isRetryableMethod(fo *FailoverOptions , method string) bool {
130161 m := strings.ToUpper(strings.TrimSpace(method))
131- if cfg == nil {
162+ if fo == nil {
132163 return defaultRetryableMethods[m]
133164 }
134165
135- // If not configured, use defaults.
136- if len(cfg.RetryableMethods) == 0 {
166+ if len(fo.RetryableMethods) == 0 {
137167 return defaultRetryableMethods[m]
138168 }
139169
140- for _, v := range cfg .RetryableMethods {
170+ for _, v := range fo .RetryableMethods {
141171 if strings.ToUpper(strings.TrimSpace(v)) == m {
142172 return true
143173 }
@@ -208,3 +238,35 @@ func tinyBackoff(ctx context.Context) {
208238 case <-ctx.Done():
209239 }
210240}
241+
242+ func shouldFailoverOnStatus(fo *FailoverOptions, statusCode int) bool {
243+ if fo == nil || len(fo.FailoverOnStatusCodes) == 0 {
244+ return false
245+ }
246+ for _, sc := range fo.FailoverOnStatusCodes {
247+ if sc == statusCode {
248+ return true
249+ }
250+ }
251+ return false
252+ }
253+
254+ // serverOrder maps an attempt index (0, 1, 2, …) to a server index.
255+ // Different strategies produce different orderings.
256+ type serverOrder func(attempt int) int
257+
258+ // serverOrderFor returns a serverOrder for the given strategy, or nil when
259+ // failover should not be applied (unknown/disabled strategy, ≤1 server).
260+ func serverOrderFor(strategy FailoverStrategy, numServers int) serverOrder {
261+ s := strings.TrimSpace(strings.ToLower(string(strategy)))
262+ if s == "" || s == string(FailoverNone) || numServers <= 1 {
263+ return nil
264+ }
265+ switch s {
266+ case strings.ToLower(string(FailoverRoundRobin)):
267+ // Sequential: 0, 1, 2, …
268+ return func(attempt int) int { return attempt % numServers }
269+ default:
270+ return nil
271+ }
272+ }
0 commit comments