Skip to content

Commit b047752

Browse files
author
Benjamin Rewis
committed
GODRIVER-1849 Add new WaitQueueTimeoutError to wrap context error (#568)
1 parent a058925 commit b047752

File tree

4 files changed

+59
-6
lines changed

4 files changed

+59
-6
lines changed

x/mongo/driver/topology/CMAP_spec_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func runCMAPTest(t *testing.T, testFileName string) {
160160
for len(testInfo.backgroundThreadErrors) > 0 {
161161
bgErr := <-testInfo.backgroundThreadErrors
162162
errs = append(errs, bgErr)
163-
if bgErr != nil && strings.ToLower(test.Error.Message) == bgErr.Error() {
163+
if bgErr != nil && strings.Contains(bgErr.Error(), strings.ToLower(test.Error.Message)) {
164164
erroredCorrectly = true
165165
break
166166
}

x/mongo/driver/topology/errors.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package topology
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
)
46

57
// ConnectionError represents a connection error.
68
type ConnectionError struct {
@@ -25,3 +27,20 @@ func (e ConnectionError) Error() string {
2527
func (e ConnectionError) Unwrap() error {
2628
return e.Wrapped
2729
}
30+
31+
// WaitQueueTimeoutError represents a timeout when requesting a connection from the pool
32+
type WaitQueueTimeoutError struct {
33+
Wrapped error
34+
}
35+
36+
func (w WaitQueueTimeoutError) Error() string {
37+
errorMsg := "timed out while checking out a connection from connection pool"
38+
if w.Wrapped != nil {
39+
return fmt.Sprintf("%s: %s", errorMsg, w.Wrapped.Error())
40+
}
41+
return errorMsg
42+
}
43+
44+
func (w WaitQueueTimeoutError) Unwrap() error {
45+
return w.Wrapped
46+
}

x/mongo/driver/topology/pool.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ var ErrConnectionClosed = ConnectionError{ConnectionID: "<closed>", message: "co
3131
// ErrWrongPool is return when a connection is returned to a pool it doesn't belong to.
3232
var ErrWrongPool = PoolError("connection does not belong to this pool")
3333

34-
// ErrWaitQueueTimeout is returned when the request to get a connection from the pool timesout when on the wait queue
35-
var ErrWaitQueueTimeout = PoolError("timed out while checking out a connection from connection pool")
36-
3734
// PoolError is an error returned from a Pool method.
3835
type PoolError string
3936

@@ -340,7 +337,10 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
340337
Reason: event.ReasonTimedOut,
341338
})
342339
}
343-
return nil, ErrWaitQueueTimeout
340+
errWaitQueueTimeout := WaitQueueTimeoutError{
341+
Wrapped: ctx.Err(),
342+
}
343+
return nil, errWaitQueueTimeout
344344
}
345345

346346
// This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between

x/mongo/driver/topology/pool_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,40 @@ func TestPool(t *testing.T) {
689689
noerr(t, err)
690690
})
691691
})
692+
t.Run("wait queue timeout error", func(t *testing.T) {
693+
cleanup := make(chan struct{})
694+
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
695+
<-cleanup
696+
_ = nc.Close()
697+
})
698+
d := newdialer(&net.Dialer{})
699+
pc := poolConfig{
700+
Address: address.Address(addr.String()),
701+
MaxPoolSize: 1,
702+
}
703+
p, err := newPool(pc, WithDialer(func(Dialer) Dialer { return d }))
704+
noerr(t, err)
705+
err = p.connect()
706+
noerr(t, err)
707+
708+
// get first connection.
709+
_, err = p.get(context.Background())
710+
noerr(t, err)
711+
712+
// Set a short timeout and get again.
713+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
714+
defer cancel()
715+
_, err = p.get(ctx)
716+
assert.NotNil(t, err, "expected a WaitQueueTimeout; got nil")
717+
718+
// Assert that error received is WaitQueueTimeoutError with context deadline exceeded.
719+
wqtErr, ok := err.(WaitQueueTimeoutError)
720+
assert.True(t, ok, "expected a WaitQueueTimeoutError; got %v", err)
721+
assert.True(t, wqtErr.Unwrap() == context.DeadlineExceeded,
722+
"expected a timeout error; got %v", wqtErr)
723+
724+
close(cleanup)
725+
})
692726
}
693727

694728
type sleepDialer struct {

0 commit comments

Comments
 (0)