@@ -126,28 +126,32 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
126126 timeout = c .defaultTimeout
127127 }
128128
129- serverCount := len (c .servers )
130- var specificServer driver.Connection
129+ var server driver.Connection
130+ var serverCount int
131+ var durationPerRequest time.Duration
132+
131133 if v := ctx .Value (keyEndpoint ); v != nil {
132134 if endpoint , ok := v .(string ); ok {
133135 // Override pool to only specific server if it is found
134136 if s , ok := c .getSpecificServer (endpoint ); ok {
137+ server = s
138+ durationPerRequest = timeout
135139 serverCount = 1
136- specificServer = s
137140 }
138141 }
139142 }
140143
141- timeoutDivider := math .Max (1.0 , math .Min (3.0 , float64 (serverCount )))
142- attempt := 1
143- s := specificServer
144- if s == nil {
145- s = c .getCurrentServer ()
144+ if server == nil {
145+ server , serverCount = c .getCurrentServer ()
146+ timeoutDivider := math .Max (1.0 , math .Min (3.0 , float64 (serverCount )))
147+ durationPerRequest = time .Duration (float64 (timeout ) / timeoutDivider )
146148 }
149+
150+ attempt := 1
147151 for {
148152 // Send request to specific endpoint with a 1/3 timeout (so we get 3 attempts)
149- serverCtx , cancel := context .WithTimeout (ctx , time . Duration ( float64 ( timeout ) / timeoutDivider ) )
150- resp , err := s .Do (serverCtx , req )
153+ serverCtx , cancel := context .WithTimeout (ctx , durationPerRequest )
154+ resp , err := server .Do (serverCtx , req )
151155 cancel ()
152156
153157 isNoLeaderResponse := false
@@ -162,8 +166,8 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
162166 err = aerr
163167 }
164168 }
165-
166169 }
170+
167171 if ! isNoLeaderResponse || ! followLeaderRedirect {
168172 if err == nil {
169173 // We're done
@@ -189,15 +193,13 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
189193
190194 // Failed, try next server
191195 attempt ++
192- if specificServer != nil {
196+ if attempt > serverCount {
193197 // A specific server was specified, no failover.
194- return nil , driver .WithStack (err )
195- }
196- if attempt > len (c .servers ) {
198+ // or
197199 // We've tried all servers. Giving up.
198200 return nil , driver .WithStack (err )
199201 }
200- s = c .getNextServer ()
202+ server = c .getNextServer ()
201203 }
202204}
203205
@@ -321,11 +323,11 @@ func (c *clusterConnection) Protocols() driver.ProtocolSet {
321323 return result
322324}
323325
324- // getCurrentServer returns the currently used server.
325- func (c * clusterConnection ) getCurrentServer () driver.Connection {
326+ // getCurrentServer returns the currently used server and number of servers .
327+ func (c * clusterConnection ) getCurrentServer () ( driver.Connection , int ) {
326328 c .mutex .RLock ()
327329 defer c .mutex .RUnlock ()
328- return c .servers [c .current ]
330+ return c .servers [c .current ], len ( c . servers )
329331}
330332
331333// getSpecificServer returns the server with the given endpoint.
0 commit comments