Skip to content

Commit e63cef3

Browse files
authored
test (#312)
1 parent 2aac9e3 commit e63cef3

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

cluster/cluster.go

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -126,28 +126,32 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
126126
timeout = c.defaultTimeout
127127
}
128128

129-
serverCount := len(c.servers)
130-
var specificServer driver.Connection
129+
var server driver.Connection
130+
var serverCount int
131+
var durationPerRequest time.Duration
132+
131133
if v := ctx.Value(keyEndpoint); v != nil {
132134
if endpoint, ok := v.(string); ok {
133135
// Override pool to only specific server if it is found
134136
if s, ok := c.getSpecificServer(endpoint); ok {
137+
server = s
138+
durationPerRequest = timeout
135139
serverCount = 1
136-
specificServer = s
137140
}
138141
}
139142
}
140143

141-
timeoutDivider := math.Max(1.0, math.Min(3.0, float64(serverCount)))
142-
attempt := 1
143-
s := specificServer
144-
if s == nil {
145-
s = c.getCurrentServer()
144+
if server == nil {
145+
server, serverCount = c.getCurrentServer()
146+
timeoutDivider := math.Max(1.0, math.Min(3.0, float64(serverCount)))
147+
durationPerRequest = time.Duration(float64(timeout) / timeoutDivider)
146148
}
149+
150+
attempt := 1
147151
for {
148152
// Send request to specific endpoint with a 1/3 timeout (so we get 3 attempts)
149-
serverCtx, cancel := context.WithTimeout(ctx, time.Duration(float64(timeout)/timeoutDivider))
150-
resp, err := s.Do(serverCtx, req)
153+
serverCtx, cancel := context.WithTimeout(ctx, durationPerRequest)
154+
resp, err := server.Do(serverCtx, req)
151155
cancel()
152156

153157
isNoLeaderResponse := false
@@ -162,8 +166,8 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
162166
err = aerr
163167
}
164168
}
165-
166169
}
170+
167171
if !isNoLeaderResponse || !followLeaderRedirect {
168172
if err == nil {
169173
// We're done
@@ -189,15 +193,13 @@ func (c *clusterConnection) Do(ctx context.Context, req driver.Request) (driver.
189193

190194
// Failed, try next server
191195
attempt++
192-
if specificServer != nil {
196+
if attempt > serverCount {
193197
// A specific server was specified, no failover.
194-
return nil, driver.WithStack(err)
195-
}
196-
if attempt > len(c.servers) {
198+
// or
197199
// We've tried all servers. Giving up.
198200
return nil, driver.WithStack(err)
199201
}
200-
s = c.getNextServer()
202+
server = c.getNextServer()
201203
}
202204
}
203205

@@ -321,11 +323,11 @@ func (c *clusterConnection) Protocols() driver.ProtocolSet {
321323
return result
322324
}
323325

324-
// getCurrentServer returns the currently used server.
325-
func (c *clusterConnection) getCurrentServer() driver.Connection {
326+
// getCurrentServer returns the currently used server and number of servers.
327+
func (c *clusterConnection) getCurrentServer() (driver.Connection, int) {
326328
c.mutex.RLock()
327329
defer c.mutex.RUnlock()
328-
return c.servers[c.current]
330+
return c.servers[c.current], len(c.servers)
329331
}
330332

331333
// getSpecificServer returns the server with the given endpoint.

0 commit comments

Comments
 (0)