diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 7a1a35b3b7..24ad6a3a51 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net" + "os" "strings" "sync" "sync/atomic" @@ -56,7 +57,7 @@ type connection struct { nc net.Conn // When nil, the connection is closed. addr address.Address idleTimeout time.Duration - idleDeadline atomic.Value // Stores a time.Time + idleStart atomic.Value // Stores a time.Time desc description.Server helloRTT time.Duration compressor wiremessage.CompressorID @@ -520,25 +521,65 @@ func (c *connection) close() error { return err } +// closed returns true if the connection has been closed by the driver. func (c *connection) closed() bool { return atomic.LoadInt64(&c.state) == connDisconnected } +// isAlive returns true if the connection is alive and ready to be used for an +// operation. +// +// Note that the liveness check can be slow (at least 1ms), so isAlive only +// checks the liveness of the connection if it's been idle for at least 10 +// seconds. For frequently in-use connections, a network error during an +// operation will be the first indication of a dead connection. +func (c *connection) isAlive() bool { + if c.nc == nil { + return false + } + + // If the connection has been idle for less than 10 seconds, skip the + // liveness check. + // + // The 10-second idle bypass is based on the liveness check implementation + // in the Python Driver. That implementation uses 1 second as the idle + // threshold, but we chose to be more conservative in the Go Driver because + // this is new behavior with unknown side effects. See + // https://github.com/mongodb/mongo-python-driver/blob/e6b95f65953e01e435004af069a6976473eaf841/pymongo/synchronous/pool.py#L983-L985 + idleStart, ok := c.idleStart.Load().(time.Time) + if !ok || idleStart.Add(10*time.Second).After(time.Now()) { + return true + } + + // Set a 1ms read deadline and attempt to read 1 byte from the connection. + // Expect it to block for 1ms then return a deadline exceeded error. If it + // returns any other error, the connection is not usable, so return false. + // If it doesn't return an error and actually reads data, the connection is + // also not usable, so return false. + // + // Note that we don't need to un-set the read deadline because the "read" + // and "write" methods always reset the deadlines.
+ err := c.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)) + if err != nil { + return false + } + var b [1]byte + _, err = c.nc.Read(b[:]) + return errors.Is(err, os.ErrDeadlineExceeded) +} + func (c *connection) idleTimeoutExpired() bool { - now := time.Now() - if c.idleTimeout > 0 { - idleDeadline, ok := c.idleDeadline.Load().(time.Time) - if ok && now.After(idleDeadline) { - return true - } + if c.idleTimeout == 0 { + return false } - return false + idleStart, ok := c.idleStart.Load().(time.Time) + return ok && idleStart.Add(c.idleTimeout).Before(time.Now()) } -func (c *connection) bumpIdleDeadline() { +func (c *connection) bumpIdleStart() { if c.idleTimeout > 0 { - c.idleDeadline.Store(time.Now().Add(c.idleTimeout)) + c.idleStart.Store(time.Now()) } } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 88f14fc533..f4d0cd9d7d 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" @@ -278,7 +279,7 @@ func TestConnection(t *testing.T) { want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} err := conn.writeWireMessage(context.Background(), want) - noerr(t, err) + require.NoError(t, err) got := tnc.buf if !cmp.Equal(got, want) { t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want) @@ -471,7 +472,7 @@ func TestConnection(t *testing.T) { conn.cancellationListener = listener got, err := conn.readWireMessage(context.Background()) - noerr(t, err) + require.NoError(t, err) if !cmp.Equal(got, want) { t.Errorf("did not read full wire message. got %v; want %v", got, want) } @@ -1099,3 +1100,85 @@ func (tcl *testCancellationListener) assertCalledOnce(t *testing.T) { assert.Equal(t, 1, tcl.numListen, "expected Listen to be called once, got %d", tcl.numListen) assert.Equal(t, 1, tcl.numStopListening, "expected StopListening to be called once, got %d", tcl.numListen) } + +func TestConnection_IsAlive(t *testing.T) { + t.Parallel() + + t.Run("uninitialized", func(t *testing.T) { + t.Parallel() + + conn := newConnection("") + assert.False(t, + conn.isAlive(), + "expected isAlive for an uninitialized connection to always return false") + }) + + t.Run("connection open", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.True(t, + conn.isAlive(), + "expected isAlive for an open connection to return true") + }) + + t.Run("connection closed", func(t *testing.T) { + t.Parallel() + + conns := make(chan net.Conn) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + conns <- nc + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + // Close the connection before calling isAlive. 
+ nc := <-conns + err = nc.Close() + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for a closed connection to return false") + }) + + t.Run("connection reads data", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Write some data to the connection before calling isAlive. + _, err := nc.Write([]byte{5, 0, 0, 0, 0}) + require.NoError(t, err) + + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for an open connection that reads data to return false") + }) +} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 499d2854dd..d6568e844f 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -169,8 +169,11 @@ type reason struct { // connectionPerished checks if a given connection is perished and should be removed from the pool. func connectionPerished(conn *connection) (reason, bool) { switch { - case conn.closed(): - // A connection would only be closed if it encountered a network error during an operation and closed itself. + case conn.closed() || !conn.isAlive(): + // A connection would only be closed if it encountered a network error + // during an operation and closed itself. If a connection is not alive + // (e.g. the connection was closed by the server-side), it's also + // considered a network error. return reason{ loggerConn: logger.ReasonConnClosedError, event: event.ReasonError, @@ -900,13 +903,15 @@ func (p *pool) checkInNoEvent(conn *connection) error { return nil } - // Bump the connection idle deadline here because we're about to make the connection "available". - // The idle deadline is used to determine when a connection has reached its max idle time and - // should be closed. A connection reaches its max idle time when it has been "available" in the - // idle connections stack for more than the configured duration (maxIdleTimeMS). Set it before - // we call connectionPerished(), which checks the idle deadline, because a newly "available" - // connection should never be perished due to max idle time. - conn.bumpIdleDeadline() + // Bump the connection idle start time here because we're about to make the + // connection "available". The idle start time is used to determine how long + // a connection has been idle and when it has reached its max idle time and + // should be closed. A connection reaches its max idle time when it has been + // "available" in the idle connections stack for more than the configured + // duration (maxIdleTimeMS). Set it before we call connectionPerished(), + // which checks the idle deadline, because a newly "available" connection + // should never be perished due to max idle time. 
+ conn.bumpIdleStart() r, perished := connectionPerished(conn) if !perished && conn.pool.getState() == poolClosed { diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index f5e3ddfd6d..591399a30b 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -71,14 +71,14 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.closeConnection(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -97,7 +97,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 5; i++ { p.close(context.Background()) @@ -119,16 +119,16 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := range conns { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -156,16 +156,16 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := 0; i < 2; i++ { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -192,10 +192,10 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) closed := make(chan struct{}) started := make(chan struct{}) @@ -218,7 +218,7 @@ func TestPool(t *testing.T) { // connection pool. <-started _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -239,13 +239,13 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out 2 connections from the pool and add them to a conns slice. conns := make([]*connection, 2) for i := 0; i < 2; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = c } @@ -253,10 +253,10 @@ func TestPool(t *testing.T) { // Check out a 3rd connection from the pool and immediately check it back in so there is // a mixture of in-use and idle connections. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that waits for the pool to start closing, then checks in the // 2 in-use connections. 
Assert that both connections are still connected during @@ -269,7 +269,7 @@ func TestPool(t *testing.T) { assert.Equalf(t, connConnected, c.state, "expected conn to still be connected") err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } }() @@ -295,16 +295,16 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) c1 := &Connection{connection: c} err = c1.Close() - noerr(t, err) + require.NoError(t, err) }) }) t.Run("ready", func(t *testing.T) { @@ -325,12 +325,12 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = conn } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") @@ -339,17 +339,17 @@ func TestPool(t *testing.T) { p.clear(nil, nil) for _, conn := range conns { err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connections") err = p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 3; i++ { _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 3, p.totalConnectionCount(), "should have 3 total connections") @@ -362,7 +362,7 @@ func TestPool(t *testing.T) { p := newPool(poolConfig{}) for i := 0; i < 5; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } p.close(context.Background()) @@ -382,27 +382,27 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { err = p.ready() - noerr(t, err) + require.NoError(t, err) p.clear(nil, nil) } err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -421,12 +421,12 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -435,7 +435,7 @@ func TestPool(t *testing.T) { defer wg.Done() for i := 0; i < 1000; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } }() @@ -450,12 +450,12 @@ func TestPool(t *testing.T) { wg.Wait() err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -475,7 +475,7 @@ func TestPool(t *testing.T) { }) })) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) var want 
error = ConnectionError{Wrapped: dialErr, init: true} @@ -513,25 +513,25 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out a connection and assert that the idle timeout is properly set then check it // back into the pool. c1, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") assert.Equalf(t, 1, p.totalConnectionCount(), "pool should have 1 total connection") assert.Equalf(t, time.Millisecond, c1.idleTimeout, "connection should have a 1ms idle timeout") err = p.checkIn(c1) - noerr(t, err) + require.NoError(t, err) // Sleep for more than the 1ms idle timeout and then try to check out a connection. // Expect that the previously checked-out connection is closed because it's idle and a // new connection is created. time.Sleep(50 * time.Millisecond) c2, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Assert that the connection pointers are not equal. Don't use "assert.NotEqual" because it asserts // non-equality of fields, possibly accessing some fields non-atomically and causing a race condition. assert.True(t, c1 != c2, "expected a new connection on 2nd check out after idle timeout expires") @@ -556,14 +556,14 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") @@ -584,7 +584,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -612,7 +612,7 @@ func TestPool(t *testing.T) { }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) assert.IsTypef(t, ConnectionError{}, err, "expected a ConnectionError") @@ -655,11 +655,11 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Set a short timeout and check out again. ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -696,11 +696,11 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out the 1 connection that the pool will create. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that tries to check out another connection with no timeout. Expect // this goroutine to block (wait in the wait queue) until the checked-out connection is @@ -711,7 +711,7 @@ func TestPool(t *testing.T) { defer wg.Done() _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) }() // Run lots of check-out attempts with a low timeout and assert that each one fails with @@ -727,7 +727,7 @@ func TestPool(t *testing.T) { // Check-in the connection we checked out earlier and wait for the checkOut() goroutine // to resume. 
err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) wg.Wait() p.close(context.Background()) @@ -754,14 +754,14 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out two connections (MaxPoolSize) so that subsequent checkOut() calls should // block until a connection is checked back in or removed from the pool. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 2, d.lenopened(), "should have opened 2 connection") assert.Equalf(t, 2, p.totalConnectionCount(), "pool should have 2 total connection") assert.Equalf(t, 0, p.availableConnectionCount(), "pool should have 0 idle connection") @@ -786,10 +786,10 @@ func TestPool(t *testing.T) { c.close() start = time.Now() err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) }() _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.WithinDurationf( t, time.Now(), @@ -820,11 +820,11 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Use a canceled context to check out another connection. cancelCtx, cancel := context.WithCancel(context.Background()) @@ -839,6 +839,79 @@ func TestPool(t *testing.T) { assert.Containsf(t, err.Error(), "canceled", `expected error message to contain "canceled"`) } + p.close(context.Background()) + }) + t.Run("discards connections closed by the server side", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + + ncs := make(chan net.Conn, 2) + addr := bootstrapConnections(t, 2, func(nc net.Conn) { + // Send all "server-side" connections to a channel so we can + // interact with them during the test. + ncs <- nc + + <-cleanup + _ = nc.Close() + }) + + d := newdialer(&net.Dialer{}) + p := newPool(poolConfig{ + Address: address.Address(addr.String()), + }, WithDialer(func(Dialer) Dialer { return d })) + err := p.ready() + require.NoError(t, err) + + // Add 1 idle connection to the pool by checking-out and checking-in + // a connection. + conn, err := p.checkOut(context.Background()) + require.NoError(t, err) + err = p.checkIn(conn) + require.NoError(t, err) + assertConnectionsOpened(t, d, 1) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + + // Make that connection appear as if it's been idle for a minute. + conn.idleStart.Store(time.Now().Add(-1 * time.Minute)) + + // Close the "server-side" of the connection we just created. The idle + // connection in the pool is now unusable because the "server-side" + // closed it. + nc := <-ncs + err = nc.Close() + require.NoError(t, err) + + // In a separate goroutine, write a valid wire message to the 2nd + // connection that's about to be created. Stop waiting for a 2nd + // connection after 100ms to prevent leaking a goroutine. + go func() { + select { + case nc := <-ncs: + _, err = nc.Write([]byte{5, 0, 0, 0, 0}) + require.NoError(t, err, "Write error") + case <-time.After(100 * time.Millisecond): + } + }() + + // Check out a connection and try to read from it. 
Expect the pool to + // discard the connection that was closed by the "server-side" and + // return a newly created connection instead. + conn, err = p.checkOut(context.Background()) + require.NoError(t, err) + msg, err := conn.readWireMessage(context.Background()) + require.NoError(t, err) + assert.Equal(t, []byte{5, 0, 0, 0, 0}, msg) + + err = p.checkIn(conn) + require.NoError(t, err) + + assertConnectionsOpened(t, d, 2) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + p.close(context.Background()) }) }) @@ -860,15 +933,15 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, p.availableConnectionCount(), "should be no idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) assert.NotNilf(t, err, "expected an error trying to return the same conn to the pool twice") @@ -894,10 +967,10 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should have 1 total connection in pool") @@ -905,7 +978,7 @@ func TestPool(t *testing.T) { p.close(context.Background()) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenclosed(), "should have closed 1 connection") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connection in pool") @@ -925,14 +998,14 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.checkIn(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -957,18 +1030,18 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Sleep for 110ms, which will exceed the 100ms connection idle timeout. Then check the // connection back in and expect that it is not closed because checkIn() should bump the // connection idle deadline. 
time.Sleep(110 * time.Millisecond) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 1, p.availableConnectionCount(), "should have 1 idle connections in pool") @@ -992,7 +1065,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) // Wait for maintain() to open 3 connections. @@ -1004,7 +1077,7 @@ func TestPool(t *testing.T) { // and tries to create a new connection. time.Sleep(100 * time.Millisecond) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assertConnectionsClosed(t, d, 3) assert.Equalf(t, 4, d.lenopened(), "should have opened 4 connections") @@ -1032,7 +1105,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1058,7 +1131,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 2) assert.Equalf(t, 2, p.availableConnectionCount(), "should be 2 idle connections in pool") @@ -1084,18 +1157,18 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out and check in 3 connections. Assert that there are 3 total and 3 idle // connections in the pool. 
conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for _, c := range conns { err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1107,7 +1180,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) @@ -1135,7 +1208,7 @@ func TestPool(t *testing.T) { ConnectTimeout: defaultConnectionTimeout, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") assert.Equalf(t, 3, p.totalConnectionCount(), "should be 3 total connection in pool") @@ -1143,7 +1216,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) @@ -1185,7 +1258,7 @@ func TestBackgroundRead(t *testing.T) { }() _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1193,10 +1266,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1225,7 +1298,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an full message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1233,10 +1306,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1245,7 +1318,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1272,7 +1345,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an incomplete head. 
time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1280,10 +1353,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1292,7 +1365,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1324,7 +1397,7 @@ func TestBackgroundRead(t *testing.T) { // message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1332,10 +1405,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1344,7 +1417,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1374,11 +1447,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write a complete message _, err = nc.Write([]byte{2, 3, 4}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1386,10 +1459,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1398,7 +1471,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1424,11 +1497,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write an incomplete message _, err = nc.Write([]byte{2}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1436,10 +1509,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.WithTimeout(context.Background(), &timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1448,7 +1521,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: diff --git 
a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index b7ffd079a0..5ab5692840 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -619,11 +619,11 @@ func TestServer(t *testing.T) { })) s.state = serverConnected err := s.pool.ready() - noerr(t, err) + require.NoError(t, err) defer s.pool.close(context.Background()) conn, err := s.Connection(context.Background()) - noerr(t, err) + require.NoError(t, err) if d.lenopened() != 1 { t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1) } @@ -645,7 +645,7 @@ func TestServer(t *testing.T) { <-ch runtime.Gosched() err = conn.Close() - noerr(t, err) + require.NoError(t, err) wg.Wait() close(cleanup) }) diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index 5a96d135ac..3b4306e606 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -17,6 +17,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/internal/serverselector" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" ) @@ -29,7 +30,7 @@ func TestTopologyErrors(t *testing.T) { t.Run("errors are wrapped", func(t *testing.T) { t.Run("server selection error", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index 0f445b5d32..5aa856c443 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -29,14 +29,6 @@ import ( const testTimeout = 2 * time.Second -func noerr(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Errorf("Unexpected error: %v", err) - t.FailNow() - } -} - func compareErrors(err1, err2 error) bool { if err1 == nil && err2 == nil { return true @@ -66,7 +58,7 @@ func TestServerSelection(t *testing.T) { t.Run("Success", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Servers: []description.Server{ {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, @@ -78,7 +70,7 @@ func TestServerSelection(t *testing.T) { subCh <- desc srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst) - noerr(t, err) + require.NoError(t, err) if len(srvs) != 1 { t.Errorf("Incorrect number of descriptions returned. 
got %d; want %d", len(srvs), 1) } @@ -88,7 +80,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Min Version Too High", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.TopologyKindSingle, Servers: []description.Server{ @@ -111,7 +103,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Max Version Too Low", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.TopologyKindSingle, Servers: []description.Server{ @@ -134,7 +126,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Updated", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{Servers: []description.Server{}} subCh := make(chan description.Topology, 1) subCh <- desc @@ -142,7 +134,7 @@ func TestServerSelection(t *testing.T) { resp := make(chan []description.Server) go func() { srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, selectFirst) - noerr(t, err) + require.NoError(t, err) resp <- srvs }() @@ -182,7 +174,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -211,10 +203,10 @@ func TestServerSelection(t *testing.T) { }) t.Run("findServer returns topology kind", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id, defaultConnectionTimeout) - noerr(t, err) + require.NoError(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) desc.Kind = description.TopologyKindSingle @@ -223,7 +215,7 @@ func TestServerSelection(t *testing.T) { selected := description.Server{Addr: address.Address("one")} ss, err := topo.FindServer(selected) - noerr(t, err) + require.NoError(t, err) if ss.Kind != description.TopologyKindSingle { t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.TopologyKindSingle) } @@ -231,7 +223,7 @@ func TestServerSelection(t *testing.T) { t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) { // Assert that the server selection fast path does not create a Subscription or check for timeout errors. 
topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) primaryAddr := address.Address("one") @@ -243,7 +235,7 @@ func TestServerSelection(t *testing.T) { topo.desc.Store(desc) for _, srv := range desc.Servers { s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id, defaultConnectionTimeout) - noerr(t, err) + require.NoError(t, err) topo.servers[srv.Addr] = s } @@ -253,13 +245,13 @@ func TestServerSelection(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() selectedServer, err := topo.SelectServer(ctx, &serverselector.Write{}) - noerr(t, err) + require.NoError(t, err) selectedAddr := selectedServer.(*SelectedServer).address assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr) }) t.Run("default to selecting from subscription if fast path fails", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ @@ -278,7 +270,7 @@ func TestSessionTimeout(t *testing.T) { t.Run("UpdateSessionTimeout", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.fsm.Servers = []description.Server{ { @@ -305,7 +297,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MultipleUpdates", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.TopologyKindReplicaSetWithPrimary topo.servers["foo"] = nil topo.servers["bar"] = nil @@ -348,7 +340,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("NoUpdate", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -390,7 +382,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("TimeoutDataBearing", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -432,7 +424,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MixedSessionSupport", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.TopologyKindReplicaSetWithPrimary topo.servers["one"] = nil topo.servers["two"] = nil