From d9268fd0cb64efb7dcb1eeb59377b815e921ea1c Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Thu, 5 Sep 2024 19:03:34 -0700 Subject: [PATCH 1/4] GODRIVER-3156 Detect and discard closed idle connections. --- x/mongo/driver/topology/connection.go | 52 ++++++++++++++---- x/mongo/driver/topology/pool.go | 16 +++--- x/mongo/driver/topology/pool_test.go | 77 ++++++++++++++++++++++++++- 3 files changed, 125 insertions(+), 20 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 49a613aef8..217fc5ce5c 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net" + "os" "strings" "sync" "sync/atomic" @@ -54,7 +55,7 @@ type connection struct { nc net.Conn // When nil, the connection is closed. addr address.Address idleTimeout time.Duration - idleDeadline atomic.Value // Stores a time.Time + idleStart atomic.Value // Stores a time.Time readTimeout time.Duration writeTimeout time.Duration desc description.Server @@ -512,24 +513,53 @@ func (c *connection) close() error { } func (c *connection) closed() bool { - return atomic.LoadInt64(&c.state) == connDisconnected + if atomic.LoadInt64(&c.state) == connDisconnected { + return true + } + + // If the connection has been idle for less than 10 seconds, skip the liveness + // check. + idleStart, ok := c.idleStart.Load().(time.Time) + if !ok || idleStart.Add(10*time.Second).After(time.Now()) { + return false + } + + // Set a 1ms read deadline and attempt to read 1 byte from the wire. Expect + // it to block for 1ms then return a deadline exceeded error. If it returns + // any other error, the connection is not usable, so close it and return + // true. If it doesn't return an error and actually reads data, the + // connection is also not usable, so close it and return true. + err := c.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)) + if err != nil { + _ = c.close() + return true + } + + var b [1]byte + _, err = c.nc.Read(b[:]) + if !errors.Is(err, os.ErrDeadlineExceeded) { + _ = c.close() + return true + } + + // We don't need to un-set the read deadline because the "read" and "write" + // methods always reset the deadlines. + + return false } func (c *connection) idleTimeoutExpired() bool { - now := time.Now() - if c.idleTimeout > 0 { - idleDeadline, ok := c.idleDeadline.Load().(time.Time) - if ok && now.After(idleDeadline) { - return true - } + if c.idleTimeout == 0 { + return false } - return false + idleStart, ok := c.idleStart.Load().(time.Time) + return ok && idleStart.Add(c.idleTimeout).Before(time.Now()) } -func (c *connection) bumpIdleDeadline() { +func (c *connection) bumpIdleStart() { if c.idleTimeout > 0 { - c.idleDeadline.Store(time.Now().Add(c.idleTimeout)) + c.idleStart.Store(time.Now()) } } diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 52461eb681..3ed5335aa1 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -892,13 +892,15 @@ func (p *pool) checkInNoEvent(conn *connection) error { return nil } - // Bump the connection idle deadline here because we're about to make the connection "available". - // The idle deadline is used to determine when a connection has reached its max idle time and - // should be closed. A connection reaches its max idle time when it has been "available" in the - // idle connections stack for more than the configured duration (maxIdleTimeMS). Set it before - // we call connectionPerished(), which checks the idle deadline, because a newly "available" - // connection should never be perished due to max idle time. - conn.bumpIdleDeadline() + // Bump the connection idle start time here because we're about to make the + // connection "available". The idle start time is used to determine how long + // a connection has been idle and when it has reached its max idle time and + // should be closed. A connection reaches its max idle time when it has been + // "available" in the idle connections stack for more than the configured + // duration (maxIdleTimeMS). Set it before we call connectionPerished(), + // which checks the idle deadline, because a newly "available" connection + // should never be perished due to max idle time. + conn.bumpIdleStart() r, perished := connectionPerished(conn) if !perished && conn.pool.getState() == poolClosed { diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index bc7115ee2c..712da5ccdd 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -816,6 +816,79 @@ func TestPool(t *testing.T) { assert.Containsf(t, err.Error(), "canceled", `expected error message to contain "canceled"`) } + p.close(context.Background()) + }) + t.Run("discards connections closed by the server side", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + + ncs := make(chan net.Conn, 2) + addr := bootstrapConnections(t, 2, func(nc net.Conn) { + // Send all "server-side" connections to a channel so we can + // interact with them during the test. + ncs <- nc + + <-cleanup + _ = nc.Close() + }) + + d := newdialer(&net.Dialer{}) + p := newPool(poolConfig{ + Address: address.Address(addr.String()), + }, WithDialer(func(Dialer) Dialer { return d })) + err := p.ready() + noerr(t, err) + + // Add 1 idle connection to the pool by checking-out and checking-in + // a connection. + conn, err := p.checkOut(context.Background()) + noerr(t, err) + err = p.checkIn(conn) + noerr(t, err) + assertConnectionsOpened(t, d, 1) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + + // Make that connection appear as if it's been idle for a minute. + conn.idleStart.Store(time.Now().Add(-1 * time.Minute)) + + // Close the "server-side" of the connection we just created. The idle + // connection in the pool is now unusable because the "server-side" + // closed it. + nc := <-ncs + err = nc.Close() + noerr(t, err) + + // In a separate goroutine, write a valid wire message to the 2nd + // connection that's about to be created. Stop waiting for a 2nd + // connection after 100ms to prevent leaking a goroutine. + go func() { + select { + case nc := <-ncs: + _, err = nc.Write([]byte{5, 0, 0, 0, 0}) + noerr(t, err) + case <-time.After(100 * time.Millisecond): + } + }() + + // Check out a connection and try to read from it. Expect the pool to + // discard the connection that was closed by the "server-side" and + // return a newly created connection instead. + conn, err = p.checkOut(context.Background()) + noerr(t, err) + msg, err := conn.readWireMessage(context.Background()) + noerr(t, err) + assert.Equal(t, []byte{5, 0, 0, 0, 0}, msg) + + err = p.checkIn(conn) + noerr(t, err) + + assertConnectionsOpened(t, d, 2) + assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") + assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") + p.close(context.Background()) }) }) @@ -1076,7 +1149,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) @@ -1111,7 +1184,7 @@ func TestPool(t *testing.T) { p.idleMu.Lock() for i := 0; i < 2; i++ { p.idleConns[i].idleTimeout = time.Millisecond - p.idleConns[i].idleDeadline.Store(time.Now().Add(-1 * time.Hour)) + p.idleConns[i].idleStart.Store(time.Now().Add(-1 * time.Hour)) } p.idleMu.Unlock() assertConnectionsClosed(t, d, 2) From 39937d9cfbfb8798f067d8d77373aa1a19339cc3 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:51:20 -0700 Subject: [PATCH 2/4] Separate liveness check method, other review changes. --- x/mongo/driver/topology/connection.go | 47 ++--- x/mongo/driver/topology/connection_test.go | 5 +- x/mongo/driver/topology/pool.go | 7 +- x/mongo/driver/topology/pool_test.go | 184 +++++++++--------- x/mongo/driver/topology/server_test.go | 6 +- .../driver/topology/topology_errors_test.go | 3 +- x/mongo/driver/topology/topology_test.go | 58 +++--- 7 files changed, 154 insertions(+), 156 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 217fc5ce5c..25fe3ebd1b 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -512,40 +512,41 @@ func (c *connection) close() error { return err } +// closed returns true if the connection has been closed by the driver. func (c *connection) closed() bool { - if atomic.LoadInt64(&c.state) == connDisconnected { - return true - } + return atomic.LoadInt64(&c.state) == connDisconnected +} - // If the connection has been idle for less than 10 seconds, skip the liveness - // check. +// isAlive returns true if the connection is alive and ready to be used for an +// operation. +// +// Note that the liveness check can be slow (at least 1ms), so isAlive only +// checks the liveness of the connection if it's been idle for at least 10 +// seconds. For frequently in-use connections, a network error during an +// operation will be the first indication of a dead connection. +func (c *connection) isAlive() bool { + // If the connection has been idle for less than 10 seconds, skip the + // liveness check. idleStart, ok := c.idleStart.Load().(time.Time) if !ok || idleStart.Add(10*time.Second).After(time.Now()) { - return false + return true } - // Set a 1ms read deadline and attempt to read 1 byte from the wire. Expect - // it to block for 1ms then return a deadline exceeded error. If it returns - // any other error, the connection is not usable, so close it and return - // true. If it doesn't return an error and actually reads data, the - // connection is also not usable, so close it and return true. + // Set a 1ms read deadline and attempt to read 1 byte from the connection. + // Expect it to block for 1ms then return a deadline exceeded error. If it + // returns any other error, the connection is not usable, so return false. + // If it doesn't return an error and actually reads data, the connection is + // also not usable, so return false. + // + // Note that we don't need to un-set the read deadline because the "read" + // and "write" methods always reset the deadlines. err := c.nc.SetReadDeadline(time.Now().Add(1 * time.Millisecond)) if err != nil { - _ = c.close() - return true + return false } - var b [1]byte _, err = c.nc.Read(b[:]) - if !errors.Is(err, os.ErrDeadlineExceeded) { - _ = c.close() - return true - } - - // We don't need to un-set the read deadline because the "read" and "write" - // methods always reset the deadlines. - - return false + return errors.Is(err, os.ErrDeadlineExceeded) } func (c *connection) idleTimeoutExpired() bool { diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 946f74d8f2..6c43c38170 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -427,7 +428,7 @@ func TestConnection(t *testing.T) { want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} err := conn.writeWireMessage(context.Background(), want) - noerr(t, err) + require.NoError(t, err) got := tnc.buf if !cmp.Equal(got, want) { t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want) @@ -607,7 +608,7 @@ func TestConnection(t *testing.T) { conn.cancellationListener = listener got, err := conn.readWireMessage(context.Background()) - noerr(t, err) + require.NoError(t, err) if !cmp.Equal(got, want) { t.Errorf("did not read full wire message. got %v; want %v", got, want) } diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 3ed5335aa1..664b4f66df 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -166,8 +166,11 @@ type reason struct { // connectionPerished checks if a given connection is perished and should be removed from the pool. func connectionPerished(conn *connection) (reason, bool) { switch { - case conn.closed(): - // A connection would only be closed if it encountered a network error during an operation and closed itself. + case conn.closed() || !conn.isAlive(): + // A connection would only be closed if it encountered a network error + // during an operation and closed itself. If a connection is not alive + // (e.g. the connection was closed by the server-side), it's also + // considered a network error. return reason{ loggerConn: logger.ReasonConnClosedError, event: event.ReasonError, diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 712da5ccdd..e120aadd94 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -69,14 +69,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.closeConnection(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -93,7 +93,7 @@ func TestPool(t *testing.T) { p := newPool(poolConfig{}) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 5; i++ { p.close(context.Background()) @@ -114,16 +114,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := range conns { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -150,16 +150,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for i := 0; i < 2; i++ { err = p.checkIn(conns[i]) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") @@ -185,10 +185,10 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) closed := make(chan struct{}) started := make(chan struct{}) @@ -211,7 +211,7 @@ func TestPool(t *testing.T) { // connection pool. <-started _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -231,13 +231,13 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out 2 connections from the pool and add them to a conns slice. conns := make([]*connection, 2) for i := 0; i < 2; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = c } @@ -245,10 +245,10 @@ func TestPool(t *testing.T) { // Check out a 3rd connection from the pool and immediately check it back in so there is // a mixture of in-use and idle connections. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that waits for the pool to start closing, then checks in the // 2 in-use connections. Assert that both connections are still connected during @@ -261,7 +261,7 @@ func TestPool(t *testing.T) { assert.Equalf(t, connConnected, c.state, "expected conn to still be connected") err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } }() @@ -286,16 +286,16 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) c1 := &Connection{connection: c} err = c1.Close() - noerr(t, err) + require.NoError(t, err) }) }) t.Run("ready", func(t *testing.T) { @@ -315,12 +315,12 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) conns := make([]*connection, 3) for i := range conns { conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) conns[i] = conn } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") @@ -329,17 +329,17 @@ func TestPool(t *testing.T) { p.clear(nil, nil) for _, conn := range conns { err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connections") err = p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 3; i++ { _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 available connections") assert.Equalf(t, 3, p.totalConnectionCount(), "should have 3 total connections") @@ -352,7 +352,7 @@ func TestPool(t *testing.T) { p := newPool(poolConfig{}) for i := 0; i < 5; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } p.close(context.Background()) @@ -371,27 +371,27 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { err = p.ready() - noerr(t, err) + require.NoError(t, err) p.clear(nil, nil) } err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -409,12 +409,12 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -423,7 +423,7 @@ func TestPool(t *testing.T) { defer wg.Done() for i := 0; i < 1000; i++ { err := p.ready() - noerr(t, err) + require.NoError(t, err) } }() @@ -438,12 +438,12 @@ func TestPool(t *testing.T) { wg.Wait() err = p.ready() - noerr(t, err) + require.NoError(t, err) c, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) }) @@ -461,7 +461,7 @@ func TestPool(t *testing.T) { }) })) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) var want error = ConnectionError{Wrapped: dialErr, init: true} @@ -498,25 +498,25 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out a connection and assert that the idle timeout is properly set then check it // back into the pool. c1, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") assert.Equalf(t, 1, p.totalConnectionCount(), "pool should have 1 total connection") assert.Equalf(t, time.Millisecond, c1.idleTimeout, "connection should have a 1ms idle timeout") err = p.checkIn(c1) - noerr(t, err) + require.NoError(t, err) // Sleep for more than the 1ms idle timeout and then try to check out a connection. // Expect that the previously checked-out connection is closed because it's idle and a // new connection is created. time.Sleep(50 * time.Millisecond) c2, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Assert that the connection pointers are not equal. Don't use "assert.NotEqual" because it asserts // non-equality of fields, possibly accessing some fields non-atomically and causing a race condition. assert.True(t, c1 != c2, "expected a new connection on 2nd check out after idle timeout expires") @@ -540,14 +540,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) for i := 0; i < 100; i++ { c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 1, d.lenopened(), "should have opened 1 connection") @@ -567,7 +567,7 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) p.close(context.Background()) @@ -593,7 +593,7 @@ func TestPool(t *testing.T) { }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) assert.IsTypef(t, ConnectionError{}, err, "expected a ConnectionError") @@ -635,11 +635,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Set a short timeout and check out again. ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -675,11 +675,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out the 1 connection that the pool will create. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Start a goroutine that tries to check out another connection with no timeout. Expect // this goroutine to block (wait in the wait queue) until the checked-out connection is @@ -690,7 +690,7 @@ func TestPool(t *testing.T) { defer wg.Done() _, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) }() // Run lots of check-out attempts with a low timeout and assert that each one fails with @@ -706,7 +706,7 @@ func TestPool(t *testing.T) { // Check-in the connection we checked out earlier and wait for the checkOut() goroutine // to resume. err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) wg.Wait() p.close(context.Background()) @@ -732,14 +732,14 @@ func TestPool(t *testing.T) { WithDialer(func(Dialer) Dialer { return d }), ) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out two connections (MaxPoolSize) so that subsequent checkOut() calls should // block until a connection is checked back in or removed from the pool. c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 2, d.lenopened(), "should have opened 2 connection") assert.Equalf(t, 2, p.totalConnectionCount(), "pool should have 2 total connection") assert.Equalf(t, 0, p.availableConnectionCount(), "pool should have 0 idle connection") @@ -764,10 +764,10 @@ func TestPool(t *testing.T) { c.close() start = time.Now() err := p.checkIn(c) - noerr(t, err) + require.NoError(t, err) }() _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.WithinDurationf( t, time.Now(), @@ -797,11 +797,11 @@ func TestPool(t *testing.T) { MaxPoolSize: 1, }) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out first connection. _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Use a canceled context to check out another connection. cancelCtx, cancel := context.WithCancel(context.Background()) @@ -839,14 +839,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Add 1 idle connection to the pool by checking-out and checking-in // a connection. conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 1) assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") @@ -859,7 +859,7 @@ func TestPool(t *testing.T) { // closed it. nc := <-ncs err = nc.Close() - noerr(t, err) + require.NoError(t, err) // In a separate goroutine, write a valid wire message to the 2nd // connection that's about to be created. Stop waiting for a 2nd @@ -868,7 +868,7 @@ func TestPool(t *testing.T) { select { case nc := <-ncs: _, err = nc.Write([]byte{5, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err, "Write error") case <-time.After(100 * time.Millisecond): } }() @@ -877,13 +877,13 @@ func TestPool(t *testing.T) { // discard the connection that was closed by the "server-side" and // return a newly created connection instead. conn, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) msg, err := conn.readWireMessage(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equal(t, []byte{5, 0, 0, 0, 0}, msg) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 2) assert.Equalf(t, 1, p.availableConnectionCount(), "should be 1 idle connections in pool") @@ -909,15 +909,15 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, p.availableConnectionCount(), "should be no idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should be 1 total connection in pool") err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) err = p.checkIn(c) assert.NotNilf(t, err, "expected an error trying to return the same conn to the pool twice") @@ -942,10 +942,10 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 1, p.totalConnectionCount(), "should have 1 total connection in pool") @@ -953,7 +953,7 @@ func TestPool(t *testing.T) { p.close(context.Background()) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 1, d.lenclosed(), "should have closed 1 connection") assert.Equalf(t, 0, p.availableConnectionCount(), "should have 0 idle connections in pool") assert.Equalf(t, 0, p.totalConnectionCount(), "should have 0 total connection in pool") @@ -972,14 +972,14 @@ func TestPool(t *testing.T) { Address: address.Address(addr.String()), }) err := p1.ready() - noerr(t, err) + require.NoError(t, err) c, err := p1.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) p2 := newPool(poolConfig{}) err = p2.ready() - noerr(t, err) + require.NoError(t, err) err = p2.checkIn(c) assert.Equalf(t, ErrWrongPool, err, "expected ErrWrongPool error") @@ -1003,18 +1003,18 @@ func TestPool(t *testing.T) { MaxIdleTime: 100 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) c, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) // Sleep for 110ms, which will exceed the 100ms connection idle timeout. Then check the // connection back in and expect that it is not closed because checkIn() should bump the // connection idle deadline. time.Sleep(110 * time.Millisecond) err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) assert.Equalf(t, 0, d.lenclosed(), "should have closed 0 connections") assert.Equalf(t, 1, p.availableConnectionCount(), "should have 1 idle connections in pool") @@ -1037,7 +1037,7 @@ func TestPool(t *testing.T) { MaxIdleTime: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) defer p.close(context.Background()) // Wait for maintain() to open 3 connections. @@ -1049,7 +1049,7 @@ func TestPool(t *testing.T) { // and tries to create a new connection. time.Sleep(100 * time.Millisecond) _, err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) assertConnectionsClosed(t, d, 3) assert.Equalf(t, 4, d.lenopened(), "should have opened 4 connections") @@ -1076,7 +1076,7 @@ func TestPool(t *testing.T) { MinPoolSize: 3, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1101,7 +1101,7 @@ func TestPool(t *testing.T) { MaxPoolSize: 2, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 2) assert.Equalf(t, 2, p.availableConnectionCount(), "should be 2 idle connections in pool") @@ -1126,18 +1126,18 @@ func TestPool(t *testing.T) { MaintainInterval: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) // Check out and check in 3 connections. Assert that there are 3 total and 3 idle // connections in the pool. conns := make([]*connection, 3) for i := range conns { conns[i], err = p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) } for _, c := range conns { err = p.checkIn(c) - noerr(t, err) + require.NoError(t, err) } assert.Equalf(t, 3, d.lenopened(), "should have opened 3 connections") assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") @@ -1176,7 +1176,7 @@ func TestPool(t *testing.T) { MaintainInterval: 10 * time.Millisecond, }, WithDialer(func(Dialer) Dialer { return d })) err := p.ready() - noerr(t, err) + require.NoError(t, err) assertConnectionsOpened(t, d, 3) assert.Equalf(t, 3, p.availableConnectionCount(), "should be 3 idle connections in pool") assert.Equalf(t, 3, p.totalConnectionCount(), "should be 3 total connection in pool") diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index e23c604156..e798b16128 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -608,11 +608,11 @@ func TestServer(t *testing.T) { })) s.state = serverConnected err := s.pool.ready() - noerr(t, err) + require.NoError(t, err) defer s.pool.close(context.Background()) conn, err := s.Connection(context.Background()) - noerr(t, err) + require.NoError(t, err) if d.lenopened() != 1 { t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1) } @@ -634,7 +634,7 @@ func TestServer(t *testing.T) { <-ch runtime.Gosched() err = conn.Close() - noerr(t, err) + require.NoError(t, err) wg.Wait() close(cleanup) }) diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index c7dc7336e9..1831a16e72 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -17,6 +17,7 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/description" ) @@ -28,7 +29,7 @@ func TestTopologyErrors(t *testing.T) { t.Run("errors are wrapped", func(t *testing.T) { t.Run("server selection error", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index 18c119876f..2683afe1d7 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -33,14 +33,6 @@ import ( const testTimeout = 2 * time.Second -func noerr(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Errorf("Unexpected error: %v", err) - t.FailNow() - } -} - func compareErrors(err1, err2 error) bool { if err1 == nil && err2 == nil { return true @@ -74,7 +66,7 @@ func TestServerSelection(t *testing.T) { t.Run("Success", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Servers: []description.Server{ {Addr: address.Address("one"), Kind: description.Standalone}, @@ -87,7 +79,7 @@ func TestServerSelection(t *testing.T) { state := newServerSelectionState(selectFirst, nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) if len(srvs) != 1 { t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1) } @@ -97,7 +89,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Min Version Too High", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.Single, Servers: []description.Server{ @@ -120,7 +112,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Compatibility Error Max Version Too Low", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{ Kind: description.Single, Servers: []description.Server{ @@ -143,7 +135,7 @@ func TestServerSelection(t *testing.T) { }) t.Run("Updated", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) desc := description.Topology{Servers: []description.Server{}} subCh := make(chan description.Topology, 1) subCh <- desc @@ -152,7 +144,7 @@ func TestServerSelection(t *testing.T) { go func() { state := newServerSelectionState(selectFirst, nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) resp <- srvs }() @@ -192,7 +184,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -229,7 +221,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -265,7 +257,7 @@ func TestServerSelection(t *testing.T) { }, } topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) subCh := make(chan description.Topology, 1) subCh <- desc resp := make(chan error) @@ -288,10 +280,10 @@ func TestServerSelection(t *testing.T) { }) t.Run("findServer returns topology kind", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) desc.Kind = description.Single @@ -300,14 +292,14 @@ func TestServerSelection(t *testing.T) { selected := description.Server{Addr: address.Address("one")} ss, err := topo.FindServer(selected) - noerr(t, err) + require.NoError(t, err) if ss.Kind != description.Single { t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.Single) } }) t.Run("Update on not primary error", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) addr1 := address.Address("one") @@ -324,7 +316,7 @@ func TestServerSelection(t *testing.T) { // manually add the servers to the topology for _, srv := range desc.Servers { s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[srv.Addr] = s } @@ -342,7 +334,7 @@ func TestServerSelection(t *testing.T) { // send a not primary error to the server forcing an update serv, err := topo.FindServer(desc.Servers[0]) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&serv.state, serverConnected) _ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{}) @@ -352,7 +344,7 @@ func TestServerSelection(t *testing.T) { // server selection should discover the new topology state := newServerSelectionState(description.WriteSelector(), nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) - noerr(t, err) + require.NoError(t, err) resp <- srvs }() @@ -373,7 +365,7 @@ func TestServerSelection(t *testing.T) { t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) { // Assert that the server selection fast path does not create a Subscription or check for timeout errors. topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) primaryAddr := address.Address("one") @@ -385,7 +377,7 @@ func TestServerSelection(t *testing.T) { topo.desc.Store(desc) for _, srv := range desc.Servers { s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id) - noerr(t, err) + require.NoError(t, err) topo.servers[srv.Addr] = s } @@ -395,13 +387,13 @@ func TestServerSelection(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() selectedServer, err := topo.SelectServer(ctx, description.WriteSelector()) - noerr(t, err) + require.NoError(t, err) selectedAddr := selectedServer.(*SelectedServer).address assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr) }) t.Run("default to selecting from subscription if fast path fails", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) atomic.StoreInt64(&topo.state, topologyConnected) desc := description.Topology{ @@ -420,7 +412,7 @@ func TestSessionTimeout(t *testing.T) { t.Run("UpdateSessionTimeout", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.fsm.Servers = []description.Server{ { @@ -449,7 +441,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MultipleUpdates", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.ReplicaSetWithPrimary topo.servers["foo"] = nil topo.servers["bar"] = nil @@ -496,7 +488,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("NoUpdate", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -542,7 +534,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("TimeoutDataBearing", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ @@ -588,7 +580,7 @@ func TestSessionTimeout(t *testing.T) { }) t.Run("MixedSessionSupport", func(t *testing.T) { topo, err := New(nil) - noerr(t, err) + require.NoError(t, err) topo.fsm.Kind = description.ReplicaSetWithPrimary topo.servers["one"] = nil topo.servers["two"] = nil From c54db5e701f6f5a3f1c58de229fe6106fe3bbb25 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Fri, 27 Sep 2024 18:23:20 -0700 Subject: [PATCH 3/4] Fix missing noerr function in new pool tests. --- x/mongo/driver/topology/pool_test.go | 50 ++++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 41aeae264b..0f8a5a0570 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -1227,7 +1227,7 @@ func TestBackgroundRead(t *testing.T) { }() _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1235,10 +1235,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1267,7 +1267,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an full message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1275,10 +1275,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1287,7 +1287,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1314,7 +1314,7 @@ func TestBackgroundRead(t *testing.T) { // Wait until the operation times out, then write an incomplete head. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1322,10 +1322,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1334,7 +1334,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1366,7 +1366,7 @@ func TestBackgroundRead(t *testing.T) { // message. time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1374,10 +1374,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1386,7 +1386,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1416,11 +1416,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write a complete message _, err = nc.Write([]byte{2, 3, 4}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1428,10 +1428,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1440,7 +1440,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: @@ -1466,11 +1466,11 @@ func TestBackgroundRead(t *testing.T) { var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) + require.NoError(t, err) time.Sleep(timeout * 2) // write an incomplete message _, err = nc.Write([]byte{2}) - noerr(t, err) + require.NoError(t, err) }) p := newPool( @@ -1478,10 +1478,10 @@ func TestBackgroundRead(t *testing.T) { ) defer p.close(context.Background()) err := p.ready() - noerr(t, err) + require.NoError(t, err) conn, err := p.checkOut(context.Background()) - noerr(t, err) + require.NoError(t, err) ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) @@ -1490,7 +1490,7 @@ func TestBackgroundRead(t *testing.T) { ) assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) - noerr(t, err) + require.NoError(t, err) var bgErrs []error select { case bgErrs = <-errsCh: From b83ad92be43628fc11f7142edfad43dda2257523 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Mon, 30 Sep 2024 18:42:51 -0700 Subject: [PATCH 4/4] Add unit tests for isAlive and add comments. --- x/mongo/driver/topology/connection.go | 10 +++ x/mongo/driver/topology/connection_test.go | 82 ++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 379b64b7bd..e00363a548 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -575,8 +575,18 @@ func (c *connection) closed() bool { // seconds. For frequently in-use connections, a network error during an // operation will be the first indication of a dead connection. func (c *connection) isAlive() bool { + if c.nc == nil { + return false + } + // If the connection has been idle for less than 10 seconds, skip the // liveness check. + // + // The 10-seconds idle bypass is based on the liveness check implementation + // in the Python Driver. That implementation uses 1 second as the idle + // threshold, but we chose to be more conservative in the Go Driver because + // this is new behavior with unknown side-effects. See + // https://github.com/mongodb/mongo-python-driver/blob/e6b95f65953e01e435004af069a6976473eaf841/pymongo/synchronous/pool.py#L983-L985 idleStart, ok := c.idleStart.Load().(time.Time) if !ok || idleStart.Add(10*time.Second).After(time.Now()) { return true diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 4a30dc6fa4..07dd9ff0ec 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -1252,3 +1252,85 @@ func (tcl *testCancellationListener) assertCalledOnce(t *testing.T) { assert.Equal(t, 1, tcl.numListen, "expected Listen to be called once, got %d", tcl.numListen) assert.Equal(t, 1, tcl.numStopListening, "expected StopListening to be called once, got %d", tcl.numListen) } + +func TestConnection_IsAlive(t *testing.T) { + t.Parallel() + + t.Run("uninitialized", func(t *testing.T) { + t.Parallel() + + conn := newConnection("") + assert.False(t, + conn.isAlive(), + "expected isAlive for an uninitialized connection to always return false") + }) + + t.Run("connection open", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.True(t, + conn.isAlive(), + "expected isAlive for an open connection to return true") + }) + + t.Run("connection closed", func(t *testing.T) { + t.Parallel() + + conns := make(chan net.Conn) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + conns <- nc + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + // Close the connection before calling isAlive. + nc := <-conns + err = nc.Close() + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for a closed connection to return false") + }) + + t.Run("connection reads data", func(t *testing.T) { + t.Parallel() + + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + // Write some data to the connection before calling isAlive. + _, err := nc.Write([]byte{5, 0, 0, 0, 0}) + require.NoError(t, err) + + // Keep the connection open until the end of the test. + <-cleanup + _ = nc.Close() + }) + + conn := newConnection(address.Address(addr.String())) + err := conn.connect(context.Background()) + require.NoError(t, err) + + conn.idleStart.Store(time.Now().Add(-11 * time.Second)) + assert.False(t, + conn.isAlive(), + "expected isAlive for an open connection that reads data to return false") + }) +}