From 93d8ee1327955704eadfa47d0801f46084c1fb3e Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 14 Aug 2024 18:14:54 -0400 Subject: [PATCH 1/7] GODRIVER-3302 Handle malformatted message length properly. --- x/mongo/driver/topology/connection.go | 4 ++++ x/mongo/driver/topology/connection_test.go | 17 +++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 49a613aef8..093531f538 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -476,6 +476,10 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // read the length as an int32 size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) + if size < 4 { + err = fmt.Errorf("malformatted message length: %d", size) + return nil, err.Error(), err + } // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded // defaultMaxMessageSize instead. maxMessageSize := c.desc.MaxMessageSize diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 946f74d8f2..e7247969ab 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -546,6 +546,23 @@ func TestConnection(t *testing.T) { } listener.assertCalledOnce(t) }) + t.Run("size too small errors", func(t *testing.T) { + err := errors.New("malformatted message length: 3") + tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}} + conn := &connection{id: "foobar", nc: tnc, state: connConnected} + listener := newTestCancellationListener(false) + conn.cancellationListener = listener + + want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()} + _, got := conn.readWireMessage(context.Background()) + if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { + t.Errorf("errors do not match. got %v; want %v", got, want) + } + if !tnc.closed { + t.Errorf("failed to closeConnection net.Conn after error writing bytes.") + } + listener.assertCalledOnce(t) + }) t.Run("full message read errors", func(t *testing.T) { err := errors.New("Read error") tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}} From be66ac65223209a56d915dc386d326f45f299bad Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 15 Aug 2024 17:13:26 -0400 Subject: [PATCH 2/7] fix spelling --- x/mongo/driver/topology/connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 093531f538..44b27ac2d6 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -477,7 +477,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) if size < 4 { - err = fmt.Errorf("malformatted message length: %d", size) + err = fmt.Errorf("malformed message length: %d", size) return nil, err.Error(), err } // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded From ddbd3e9c62e10d22eb3efe367b403964c82865c7 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 20 Aug 2024 15:11:16 -0400 Subject: [PATCH 3/7] update the connection pool background read logic --- x/mongo/driver/topology/connection.go | 36 ++-- x/mongo/driver/topology/connection_test.go | 2 +- x/mongo/driver/topology/pool.go | 78 ++++--- x/mongo/driver/topology/pool_test.go | 224 +++++++++++++++++++++ 4 files changed, 295 insertions(+), 45 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 44b27ac2d6..b99ea53536 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -79,9 +79,9 @@ type connection struct { driverConnectionID uint64 generation uint64 - // awaitingResponse indicates that the server response was not completely + // awaitingResponse indicates the size of server response that was not completely // read before returning the connection to the pool. - awaitingResponse bool + awaitingResponse *int32 // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate // accessTokens in the OIDC authenticator cache. @@ -423,15 +423,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { dst, errMsg, err := c.read(ctx) if err != nil { - if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) { - // If the error was a timeout error and CSOT is enabled, instead of - // closing the connection mark it as awaiting response so the pool - // can read the response before making it available to other - // operations. - c.awaitingResponse = true - } else { - // Otherwise, use the pre-CSOT behavior and close the connection - // because we don't know if there are other bytes left to read. + if c.awaitingResponse == nil { + // If the connection was not marked as awaiting response, use the + // pre-CSOT behavior and close the connection because we don't know + // if there are other bytes left to read. c.close() } message := errMsg @@ -461,6 +456,15 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, } }() + needToWait := func(err error) bool { + // If the error was a timeout error and CSOT is enabled, instead of + // closing the connection mark it as awaiting response so the pool + // can read the response before making it available to other + // operations. + nerr := net.Error(nil) + return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) + } + // We use an array here because it only costs 4 bytes on the stack and means we'll only need to // reslice dst once instead of twice. var sizeBuf [4]byte @@ -468,8 +472,11 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst // because there might be more than one wire message waiting to be read, for example when // reading messages from an exhaust cursor. - _, err = io.ReadFull(c.nc, sizeBuf[:]) + n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { + if l := int32(n); l == 0 && needToWait(err) { + c.awaitingResponse = &l + } return nil, "incomplete read of message header", err } @@ -493,8 +500,11 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, dst := make([]byte, size) copy(dst, sizeBuf[:]) - _, err = io.ReadFull(c.nc, dst[4:]) + n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { + if l := size - 4 - int32(n); l > 0 && needToWait(err) { + c.awaitingResponse = &l + } return dst, "incomplete read of full message", err } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index e7247969ab..0521862157 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -547,7 +547,7 @@ func TestConnection(t *testing.T) { listener.assertCalledOnce(t) }) t.Run("size too small errors", func(t *testing.T) { - err := errors.New("malformatted message length: 3") + err := errors.New("malformed message length: 3") tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}} conn := &connection{id: "foobar", nc: tnc, state: connConnected} listener := newTestCancellationListener(false) diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 52461eb681..b4001cb17a 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -9,6 +9,8 @@ package topology import ( "context" "fmt" + "io" + "io/ioutil" "net" "sync" "sync/atomic" @@ -788,17 +790,27 @@ var ( // // It calls the package-global BGReadCallback function, if set, with the // address, timings, and any errors that occurred. -func bgRead(pool *pool, conn *connection) { - var start, read time.Time - start = time.Now() - errs := make([]error, 0) - connClosed := false +func bgRead(pool *pool, conn *connection, size int32) { + var err error + start := time.Now() defer func() { + read := time.Now() + errs := make([]error, 0) + connClosed := false + if err != nil { + errs = append(errs, err) + connClosed = true + err = conn.close() + if err != nil { + errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) + } + } + // No matter what happens, always check the connection back into the // pool, which will either make it available for other operations or // remove it from the pool if it was closed. - err := pool.checkInNoEvent(conn) + err = pool.checkInNoEvent(conn) if err != nil { errs = append(errs, fmt.Errorf("error checking in: %w", err)) } @@ -808,34 +820,37 @@ func bgRead(pool *pool, conn *connection) { } }() - err := conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) + err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) if err != nil { - errs = append(errs, fmt.Errorf("error setting a read deadline: %w", err)) - - connClosed = true - err := conn.close() - if err != nil { - errs = append(errs, fmt.Errorf("error closing conn after setting read deadline: %w", err)) - } - + err = fmt.Errorf("error setting a read deadline: %w", err) return } - // The context here is only used for cancellation, not deadline timeout, so - // use context.Background(). The read timeout is set by calling - // SetReadDeadline above. - _, _, err = conn.read(context.Background()) - read = time.Now() - if err != nil { - errs = append(errs, fmt.Errorf("error reading: %w", err)) - - connClosed = true - err := conn.close() + if size == 0 { + var sizeBuf [4]byte + _, err = io.ReadFull(conn.nc, sizeBuf[:]) if err != nil { - errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) + err = fmt.Errorf("error reading the message size: %w", err) + return } - - return + size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) + if size < 4 { + err = fmt.Errorf("malformed message length: %d", size) + return + } + maxMessageSize := conn.desc.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } + if uint32(size) > maxMessageSize { + err = errResponseTooLarge + return + } + size -= 4 + } + _, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) + if err != nil { + err = fmt.Errorf("error reading message of %d: %w", size, err) } } @@ -886,9 +901,10 @@ func (p *pool) checkInNoEvent(conn *connection) error { // means that connections in "awaiting response" state are checked in but // not usable, which is not covered by the current pool events. We may need // to add pool event information in the future to communicate that. - if conn.awaitingResponse { - conn.awaitingResponse = false - go bgRead(p, conn) + if conn.awaitingResponse != nil { + size := *conn.awaitingResponse + conn.awaitingResponse = nil + go bgRead(p, conn, size) return nil } diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index bc7115ee2c..35185e954c 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -9,13 +9,17 @@ package topology import ( "context" "errors" + "io" "net" + "os" + "regexp" "sync" "testing" "time" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/eventtest" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" @@ -1122,6 +1126,226 @@ func TestPool(t *testing.T) { p.close(context.Background()) }) }) + t.Run("bgRead", func(t *testing.T) { + t.Parallel() + + var errCh chan error + BGReadCallback = func(addr string, start, read time.Time, errs []error, connClosed bool) { + defer close(errCh) + + for _, err := range errs { + errCh <- err + } + } + + const sockPath = "./test.sock" + + var socket net.Listener + + setup := func(t *testing.T) { + t.Helper() + + errCh = make(chan error) + + var err error + socket, err = net.Listen("unix", sockPath) + noerr(t, err) + } + teardown := func(t *testing.T) { + t.Helper() + + os.Remove(sockPath) + } + + t.Run("incomplete read of message header", func(t *testing.T) { + setup(t) + defer teardown(t) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func(t *testing.T) { + t.Helper() + + defer wg.Done() + + conn, err := socket.Accept() + noerr(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{10, 0, 0}) + noerr(t, err) + time.Sleep(1500 * time.Millisecond) + }(t) + + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }) + }), + ) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.Nil(t, conn.awaitingResponse, "conn.awaitingResponse should be nil") + wg.Wait() + p.close(context.Background()) + close(errCh) + }) + t.Run("timeout on reading the message header", func(t *testing.T) { + setup(t) + defer teardown(t) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func(t *testing.T) { + t.Helper() + + defer wg.Done() + + conn, err := socket.Accept() + noerr(t, err) + defer conn.Close() + + time.Sleep(1500 * time.Millisecond) + _, err = conn.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) + noerr(t, err) + time.Sleep(1500 * time.Millisecond) + }(t) + + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }) + }), + ) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []*regexp.Regexp{ + regexp.MustCompile( + `^error reading message of 6: read unix .*->\.\/test.sock: i\/o timeout$`, + ), + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") + } + break + } else if i < len(errs) { + assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + } + } + }) + t.Run("timeout on reading the full message", func(t *testing.T) { + setup(t) + defer teardown(t) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func(t *testing.T) { + t.Helper() + + defer wg.Done() + + conn, err := socket.Accept() + noerr(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) + noerr(t, err) + time.Sleep(1500 * time.Millisecond) + _, err = conn.Write([]byte{2, 3, 4}) + noerr(t, err) + time.Sleep(1500 * time.Millisecond) + }(t) + + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + conn, err := net.Dial("unix", sockPath) + noerr(t, err) + return newLimitConn(conn, 10), nil + }) + }), + ) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []string{ + "error reading message of 3: EOF", + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") + } + break + } else if i < len(errs) { + assert.EqualError(t, err, errs[i], "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + } + } + }) + }) +} + +type limitConn struct { + net.Conn + r io.Reader +} + +func newLimitConn(conn net.Conn, n int64) limitConn { + return limitConn{conn, io.LimitReader(conn, n)} +} + +func (lc limitConn) Read(b []byte) (n int, err error) { + return lc.r.Read(b) } func assertConnectionsClosed(t *testing.T, dialer *dialer, count int) { From b565a23108ee970b71368ed551412ccd5f21d9a8 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 23 Aug 2024 17:56:43 -0400 Subject: [PATCH 4/7] code cleanup --- x/mongo/driver/topology/connection.go | 161 ++++++++++++++------------ x/mongo/driver/topology/pool.go | 22 +--- x/mongo/driver/topology/pool_test.go | 2 +- 3 files changed, 91 insertions(+), 94 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index b99ea53536..d0dfe08789 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -79,9 +79,9 @@ type connection struct { driverConnectionID uint64 generation uint64 - // awaitingResponse indicates the size of server response that was not completely + // awaitRemainingBytes indicates the size of server response that was not completely // read before returning the connection to the pool. - awaitingResponse *int32 + awaitRemainingBytes *int32 // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate // accessTokens in the OIDC authenticator cache. @@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { return c } -// DriverConnectionID returns the driver connection ID. -// TODO(GODRIVER-2824): change return type to int64. -func (c *connection) DriverConnectionID() uint64 { - return c.driverConnectionID -} - // setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection // configuration. func (c *connection) setGenerationNumber() { @@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool { return c.desc.LoadBalanced() } +func configureTLS(ctx context.Context, + tlsConnSource tlsConnectionSource, + nc net.Conn, + addr address.Address, + config *tls.Config, + ocspOpts *ocsp.VerifyOptions, +) (net.Conn, error) { + // Ensure config.ServerName is always set for SNI. + if config.ServerName == "" { + hostname := addr.String() + colonPos := strings.LastIndex(hostname, ":") + if colonPos == -1 { + colonPos = len(hostname) + } + + hostname = hostname[:colonPos] + config.ServerName = hostname + } + + client := tlsConnSource.Client(nc, config) + if err := clientHandshake(ctx, client); err != nil { + return nil, err + } + + // Only do OCSP verification if TLS verification is requested. + if !config.InsecureSkipVerify { + if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { + return nil, ocspErr + } + } + return client, nil +} + // connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization // handshakes. All errors returned by connect are considered "before the handshake completes" and // must be handled by calling the appropriate SDAM handshake error handler. @@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() { } } +func (c *connection) cancellationListenerCallback() { + _ = c.close() +} + func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error { if originalError == nil { return nil @@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead return originalError } -func (c *connection) cancellationListenerCallback() { - _ = c.close() -} - func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { var err error if atomic.LoadInt64(&c.state) != connConnected { @@ -423,7 +450,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { dst, errMsg, err := c.read(ctx) if err != nil { - if c.awaitingResponse == nil { + if c.awaitRemainingBytes == nil { // If the connection was not marked as awaiting response, use the // pre-CSOT behavior and close the connection because we don't know // if there are other bytes left to read. @@ -443,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { return dst, nil } +func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { + // read the length as an int32 + size := (int32(wmSizeBytes[0])) | + (int32(wmSizeBytes[1]) << 8) | + (int32(wmSizeBytes[2]) << 16) | + (int32(wmSizeBytes[3]) << 24) + + if size < 4 { + return 0, fmt.Errorf("malformed message length: %d", size) + } + // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded + // defaultMaxMessageSize instead. + maxMessageSize := c.desc.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } + if uint32(size) > maxMessageSize { + return 0, errResponseTooLarge + } + + return size, nil +} + func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) { go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback) defer func() { @@ -475,35 +525,23 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { if l := int32(n); l == 0 && needToWait(err) { - c.awaitingResponse = &l + c.awaitRemainingBytes = &l } return nil, "incomplete read of message header", err } - - // read the length as an int32 - size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) - - if size < 4 { - err = fmt.Errorf("malformed message length: %d", size) + size, err := c.parseWmSizeBytes(sizeBuf) + if err != nil { return nil, err.Error(), err } - // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded - // defaultMaxMessageSize instead. - maxMessageSize := c.desc.MaxMessageSize - if maxMessageSize == 0 { - maxMessageSize = defaultMaxMessageSize - } - if uint32(size) > maxMessageSize { - return nil, errResponseTooLarge.Error(), errResponseTooLarge - } dst := make([]byte, size) copy(dst, sizeBuf[:]) n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { - if l := size - 4 - int32(n); l > 0 && needToWait(err) { - c.awaitingResponse = &l + remainingBytes := size - 4 - int32(n) + if remainingBytes > 0 && needToWait(err) { + c.awaitRemainingBytes = &remainingBytes } return dst, "incomplete read of full message", err } @@ -551,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) { c.canStream = canStream } -func (c initConnection) supportsStreaming() bool { - return c.canStream -} - func (c *connection) setStreaming(streaming bool) { c.currentlyStreaming = streaming } @@ -568,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) { c.writeTimeout = timeout } +// DriverConnectionID returns the driver connection ID. +// TODO(GODRIVER-2824): change return type to int64. +func (c *connection) DriverConnectionID() uint64 { + return c.driverConnectionID +} + func (c *connection) ID() string { return c.id } @@ -576,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 { return c.serverConnectionID } +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. @@ -613,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool { return c.getCurrentlyStreaming() } func (c initConnection) SupportsStreaming() bool { - return c.supportsStreaming() + return c.canStream } // Connection implements the driver.Connection interface to allow reading and writing wire @@ -847,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 { return c.connection.DriverConnectionID() } -func configureTLS(ctx context.Context, - tlsConnSource tlsConnectionSource, - nc net.Conn, - addr address.Address, - config *tls.Config, - ocspOpts *ocsp.VerifyOptions, -) (net.Conn, error) { - // Ensure config.ServerName is always set for SNI. - if config.ServerName == "" { - hostname := addr.String() - colonPos := strings.LastIndex(hostname, ":") - if colonPos == -1 { - colonPos = len(hostname) - } - - hostname = hostname[:colonPos] - config.ServerName = hostname - } - - client := tlsConnSource.Client(nc, config) - if err := clientHandshake(ctx, client); err != nil { - return nil, err - } - - // Only do OCSP verification if TLS verification is requested. - if !config.InsecureSkipVerify { - if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { - return nil, ocspErr - } - } - return client, nil -} - // OIDCTokenGenID returns the OIDC token generation ID. func (c *Connection) OIDCTokenGenID() uint64 { return c.oidcTokenGenID @@ -933,11 +948,3 @@ func (c *cancellListener) StopListening() bool { c.done <- struct{}{} return c.aborted } - -func (c *connection) OIDCTokenGenID() uint64 { - return c.oidcTokenGenID -} - -func (c *connection) SetOIDCTokenGenID(genID uint64) { - c.oidcTokenGenID = genID -} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index b4001cb17a..5d232f1ebc 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "net" "sync" "sync/atomic" @@ -833,22 +832,13 @@ func bgRead(pool *pool, conn *connection, size int32) { err = fmt.Errorf("error reading the message size: %w", err) return } - size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) - if size < 4 { - err = fmt.Errorf("malformed message length: %d", size) - return - } - maxMessageSize := conn.desc.MaxMessageSize - if maxMessageSize == 0 { - maxMessageSize = defaultMaxMessageSize - } - if uint32(size) > maxMessageSize { - err = errResponseTooLarge + size, err = conn.parseWmSizeBytes(sizeBuf) + if err != nil { return } size -= 4 } - _, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) + _, err = io.CopyN(io.Discard, conn.nc, int64(size)) if err != nil { err = fmt.Errorf("error reading message of %d: %w", size, err) } @@ -901,9 +891,9 @@ func (p *pool) checkInNoEvent(conn *connection) error { // means that connections in "awaiting response" state are checked in but // not usable, which is not covered by the current pool events. We may need // to add pool event information in the future to communicate that. - if conn.awaitingResponse != nil { - size := *conn.awaitingResponse - conn.awaitingResponse = nil + if conn.awaitRemainingBytes != nil { + size := *conn.awaitRemainingBytes + conn.awaitRemainingBytes = nil go bgRead(p, conn, size) return nil } diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 35185e954c..ebb342e17c 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -1197,7 +1197,7 @@ func TestPool(t *testing.T) { `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, ) assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - assert.Nil(t, conn.awaitingResponse, "conn.awaitingResponse should be nil") + assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil") wg.Wait() p.close(context.Background()) close(errCh) From 2d988ed1950ffb6cc84e285dea41f36f8af22eea Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 13 Sep 2024 11:10:01 -0400 Subject: [PATCH 5/7] improvements --- x/mongo/driver/topology/connection.go | 12 +- x/mongo/driver/topology/pool.go | 2 +- x/mongo/driver/topology/pool_test.go | 344 +++++++++++++------------- 3 files changed, 175 insertions(+), 183 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index d0dfe08789..7a8427ccee 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -9,6 +9,7 @@ package topology import ( "context" "crypto/tls" + "encoding/binary" "errors" "fmt" "io" @@ -472,10 +473,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { // read the length as an int32 - size := (int32(wmSizeBytes[0])) | - (int32(wmSizeBytes[1]) << 8) | - (int32(wmSizeBytes[2]) << 16) | - (int32(wmSizeBytes[3]) << 24) + size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:])) if size < 4 { return 0, fmt.Errorf("malformed message length: %d", size) @@ -506,7 +504,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, } }() - needToWait := func(err error) bool { + isCSOTTimeout := func(err error) bool { // If the error was a timeout error and CSOT is enabled, instead of // closing the connection mark it as awaiting response so the pool // can read the response before making it available to other @@ -524,7 +522,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // reading messages from an exhaust cursor. n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { - if l := int32(n); l == 0 && needToWait(err) { + if l := int32(n); l == 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &l } return nil, "incomplete read of message header", err @@ -540,7 +538,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { remainingBytes := size - 4 - int32(n) - if remainingBytes > 0 && needToWait(err) { + if remainingBytes > 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &remainingBytes } return dst, "incomplete read of full message", err diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 5d232f1ebc..ddb69ada76 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -840,7 +840,7 @@ func bgRead(pool *pool, conn *connection, size int32) { } _, err = io.CopyN(io.Discard, conn.nc, int64(size)) if err != nil { - err = fmt.Errorf("error reading message of %d: %w", size, err) + err = fmt.Errorf("error discarding %d byte message: %w", size, err) } } diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index ebb342e17c..514d393a93 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -11,7 +11,6 @@ import ( "errors" "io" "net" - "os" "regexp" "sync" "testing" @@ -1126,212 +1125,207 @@ func TestPool(t *testing.T) { p.close(context.Background()) }) }) - t.Run("bgRead", func(t *testing.T) { - t.Parallel() +} + +func TestBackgroundRead(t *testing.T) { + t.Parallel() - var errCh chan error - BGReadCallback = func(addr string, start, read time.Time, errs []error, connClosed bool) { + newBGReadCallback := func(errCh chan error) func(string, time.Time, time.Time, []error, bool) { + return func(_ string, _, _ time.Time, errs []error, _ bool) { defer close(errCh) for _, err := range errs { errCh <- err } } + } - const sockPath = "./test.sock" - - var socket net.Listener + t.Run("incomplete read of message header", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) - setup := func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { t.Helper() - errCh = make(chan error) + defer func() { + _ = nc.Close() + wg.Done() + }() - var err error - socket, err = net.Listen("unix", sockPath) + _, err := nc.Write([]byte{10, 0, 0}) noerr(t, err) - } - teardown := func(t *testing.T) { - t.Helper() - - os.Remove(sockPath) - } - - t.Run("incomplete read of message header", func(t *testing.T) { - setup(t) - defer teardown(t) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() + time.Sleep(1500 * time.Millisecond) + }) - defer wg.Done() + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("tcp", addr.String()) + }) + }), + ) + err := p.ready() + noerr(t, err) - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") + wg.Wait() + p.close(context.Background()) + close(errCh) + }) + t.Run("timeout on reading the message header", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) - _, err = conn.Write([]byte{10, 0, 0}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + t.Helper() - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", sockPath) - }) - }), - ) - err := p.ready() - noerr(t, err) + defer func() { + _ = nc.Close() + wg.Done() + }() - conn, err := p.checkOut(context.Background()) + time.Sleep(1500 * time.Millisecond) + _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil") - wg.Wait() - p.close(context.Background()) - close(errCh) + time.Sleep(1500 * time.Millisecond) }) - t.Run("timeout on reading the message header", func(t *testing.T) { - setup(t) - defer teardown(t) + go func(t *testing.T) { + }(t) - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() - - defer wg.Done() - - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() - - time.Sleep(1500 * time.Millisecond) - _, err = conn.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) - - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", sockPath) - }) - }), - ) - err := p.ready() - noerr(t, err) + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("tcp", addr.String()) + }) + }), + ) + err := p.ready() + noerr(t, err) - conn, err := p.checkOut(context.Background()) - noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - err = p.checkIn(conn) - noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []*regexp.Regexp{ - regexp.MustCompile( - `^error reading message of 6: read unix .*->\.\/test.sock: i\/o timeout$`, - ), - } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []*regexp.Regexp{ + regexp.MustCompile( + `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ), + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") } + break + } else if i < len(errs) { + assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } + } + }) + t.Run("timeout on reading the full message", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback }) - t.Run("timeout on reading the full message", func(t *testing.T) { - setup(t) - defer teardown(t) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() - - defer wg.Done() - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() - - _, err = conn.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - _, err = conn.Write([]byte{2, 3, 4}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + t.Helper() - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - conn, err := net.Dial("unix", sockPath) - noerr(t, err) - return newLimitConn(conn, 10), nil - }) - }), - ) - err := p.ready() - noerr(t, err) + defer func() { + _ = nc.Close() + wg.Done() + }() - conn, err := p.checkOut(context.Background()) + var err error + _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - err = p.checkIn(conn) + time.Sleep(1500 * time.Millisecond) + _, err = nc.Write([]byte{2, 3, 4}) noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []string{ - "error reading message of 3: EOF", - } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.EqualError(t, err, errs[i], "mismatched err: %v", err) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + time.Sleep(1500 * time.Millisecond) + }) + + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr.String()) + noerr(t, err) + return newLimitConn(conn, 10), nil + }) + }), + ) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []string{ + "error discarding 3 byte message: EOF", + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") } + break + } else if i < len(errs) { + assert.EqualError(t, err, errs[i], "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } - }) + } }) } From 9bea0eba3ecf7f07af412ac2a0c0e7c6e5ab2108 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 13 Sep 2024 18:07:20 -0400 Subject: [PATCH 6/7] improve test logic --- x/mongo/driver/topology/pool_test.go | 123 ++++++++++++--------------- 1 file changed, 56 insertions(+), 67 deletions(-) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 514d393a93..29e4fce71f 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -1130,37 +1130,33 @@ func TestPool(t *testing.T) { func TestBackgroundRead(t *testing.T) { t.Parallel() - newBGReadCallback := func(errCh chan error) func(string, time.Time, time.Time, []error, bool) { + newBGReadCallback := func(errsCh chan []error) func(string, time.Time, time.Time, []error, bool) { return func(_ string, _, _ time.Time, errs []error, _ bool) { - defer close(errCh) - - for _, err := range errs { - errCh <- err - } + errsCh <- errs + close(errsCh) } } t.Run("incomplete read of message header", func(t *testing.T) { - errCh := make(chan error) + errsCh := make(chan []error) var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) t.Cleanup(func() { BGReadCallback = originalCallback }) - wg := &sync.WaitGroup{} - wg.Add(1) - addr := bootstrapConnections(t, 1, func(nc net.Conn) { - t.Helper() + const timeout = 10 * time.Millisecond + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { defer func() { + <-cleanup _ = nc.Close() - wg.Done() }() _, err := nc.Write([]byte{10, 0, 0}) noerr(t, err) - time.Sleep(1500 * time.Millisecond) }) p := newPool( @@ -1171,48 +1167,44 @@ func TestBackgroundRead(t *testing.T) { }) }), ) + defer p.close(context.Background()) err := p.ready() noerr(t, err) conn, err := p.checkOut(context.Background()) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") - wg.Wait() - p.close(context.Background()) - close(errCh) + close(errsCh) // this line causes a double close if BGReadCallback is ever called. }) t.Run("timeout on reading the message header", func(t *testing.T) { - errCh := make(chan error) + errsCh := make(chan []error) var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) t.Cleanup(func() { BGReadCallback = originalCallback }) - wg := &sync.WaitGroup{} - wg.Add(1) - addr := bootstrapConnections(t, 1, func(nc net.Conn) { - t.Helper() + const timeout = 10 * time.Millisecond + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { defer func() { + <-cleanup _ = nc.Close() - wg.Done() }() - time.Sleep(1500 * time.Millisecond) + time.Sleep(timeout * 2) _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) noerr(t, err) - time.Sleep(1500 * time.Millisecond) }) - go func(t *testing.T) { - }(t) p := newPool( poolConfig{}, @@ -1222,66 +1214,64 @@ func TestBackgroundRead(t *testing.T) { }) }), ) + defer p.close(context.Background()) err := p.ready() noerr(t, err) conn, err := p.checkOut(context.Background()) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []*regexp.Regexp{ + wantErrs := []*regexp.Regexp{ regexp.MustCompile( `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ), } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) + var bgErrs []error + select { + case bgErrs = <-errsCh: + case <-time.After(3 * time.Second): + assert.Fail(t, "did not receive expected error after waiting for 3 seconds") + } + for i, err := range bgErrs { + if i < len(wantErrs) { + assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i]) } else { assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } } }) t.Run("timeout on reading the full message", func(t *testing.T) { - errCh := make(chan error) + errsCh := make(chan []error) var originalCallback func(string, time.Time, time.Time, []error, bool) - originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) t.Cleanup(func() { BGReadCallback = originalCallback }) - wg := &sync.WaitGroup{} - wg.Add(1) - addr := bootstrapConnections(t, 1, func(nc net.Conn) { - t.Helper() + const timeout = 10 * time.Millisecond + cleanup := make(chan struct{}) + defer close(cleanup) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { defer func() { + <-cleanup _ = nc.Close() - wg.Done() }() var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) noerr(t, err) - time.Sleep(1500 * time.Millisecond) + time.Sleep(timeout * 2) _, err = nc.Write([]byte{2, 3, 4}) noerr(t, err) - time.Sleep(1500 * time.Millisecond) }) p := newPool( @@ -1294,34 +1284,33 @@ func TestBackgroundRead(t *testing.T) { }) }), ) + defer p.close(context.Background()) err := p.ready() noerr(t, err) conn, err := p.checkOut(context.Background()) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() _, err = conn.readWireMessage(ctx) regex := regexp.MustCompile( `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []string{ + wantErrs := []string{ "error discarding 3 byte message: EOF", } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.EqualError(t, err, errs[i], "mismatched err: %v", err) + var bgErrs []error + select { + case bgErrs = <-errsCh: + case <-time.After(3 * time.Second): + assert.Fail(t, "did not receive expected error after waiting for 3 seconds") + } + for i, err := range bgErrs { + if i < len(wantErrs) { + assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err) } else { assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } From 39f3021665919fb89d947f0798f3dab01455322e Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 17 Sep 2024 11:55:07 -0400 Subject: [PATCH 7/7] improve tests --- x/mongo/driver/topology/pool_test.go | 215 +++++++++++++++++++-------- 1 file changed, 157 insertions(+), 58 deletions(-) diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 4bf968e132..e0265ae4c6 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -9,7 +9,6 @@ package topology import ( "context" "errors" - "io" "net" "regexp" "sync" @@ -1159,12 +1158,7 @@ func TestBackgroundRead(t *testing.T) { }) p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("tcp", addr.String()) - }) - }), + poolConfig{Address: address.Address(addr.String())}, ) defer p.close(context.Background()) err := p.ready() @@ -1182,7 +1176,7 @@ func TestBackgroundRead(t *testing.T) { assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") close(errsCh) // this line causes a double close if BGReadCallback is ever called. }) - t.Run("timeout on reading the message header", func(t *testing.T) { + t.Run("timeout reading message header, successful background read", func(t *testing.T) { errsCh := make(chan []error) var originalCallback func(string, time.Time, time.Time, []error, bool) originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) @@ -1192,26 +1186,19 @@ func TestBackgroundRead(t *testing.T) { const timeout = 10 * time.Millisecond - cleanup := make(chan struct{}) - defer close(cleanup) addr := bootstrapConnections(t, 1, func(nc net.Conn) { defer func() { - <-cleanup _ = nc.Close() }() + // Wait until the operation times out, then write an full message. time.Sleep(timeout * 2) - _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) + _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0}) noerr(t, err) }) p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("tcp", addr.String()) - }) - }), + poolConfig{Address: address.Address(addr.String())}, ) defer p.close(context.Background()) err := p.ready() @@ -1228,26 +1215,63 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) noerr(t, err) - wantErrs := []*regexp.Regexp{ - regexp.MustCompile( - `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, - ), - } var bgErrs []error select { case bgErrs = <-errsCh: case <-time.After(3 * time.Second): assert.Fail(t, "did not receive expected error after waiting for 3 seconds") } - for i, err := range bgErrs { - if i < len(wantErrs) { - assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i]) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) - } + require.Len(t, bgErrs, 0, "expected no error from bgRead()") + }) + t.Run("timeout reading message header, incomplete head during background read", func(t *testing.T) { + errsCh := make(chan []error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) + + const timeout = 10 * time.Millisecond + + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + defer func() { + _ = nc.Close() + }() + + // Wait until the operation times out, then write an incomplete head. + time.Sleep(timeout * 2) + _, err := nc.Write([]byte{10, 0, 0}) + noerr(t, err) + }) + + p := newPool( + poolConfig{Address: address.Address(addr.String())}, + ) + defer p.close(context.Background()) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) + err = p.checkIn(conn) + noerr(t, err) + var bgErrs []error + select { + case bgErrs = <-errsCh: + case <-time.After(3 * time.Second): + assert.Fail(t, "did not receive expected error after waiting for 3 seconds") } + require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") + assert.EqualError(t, bgErrs[0], "error reading the message size: unexpected EOF") }) - t.Run("timeout on reading the full message", func(t *testing.T) { + t.Run("timeout reading message header, background read timeout", func(t *testing.T) { errsCh := make(chan []error) var originalCallback func(string, time.Time, time.Time, []error, bool) originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) @@ -1265,23 +1289,69 @@ func TestBackgroundRead(t *testing.T) { _ = nc.Close() }() + // Wait until the operation times out, then write an incomplete + // message. + time.Sleep(timeout * 2) + _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) + noerr(t, err) + }) + + p := newPool( + poolConfig{Address: address.Address(addr.String())}, + ) + defer p.close(context.Background()) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) + err = p.checkIn(conn) + noerr(t, err) + var bgErrs []error + select { + case bgErrs = <-errsCh: + case <-time.After(3 * time.Second): + assert.Fail(t, "did not receive expected error after waiting for 3 seconds") + } + require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") + wantErr := regexp.MustCompile( + `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, wantErr.MatchString(bgErrs[0].Error()), "error %q does not match pattern %q", bgErrs[0], wantErr) + }) + t.Run("timeout reading full message, successful background read", func(t *testing.T) { + errsCh := make(chan []error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) + + const timeout = 10 * time.Millisecond + + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + defer func() { + _ = nc.Close() + }() + var err error _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) noerr(t, err) time.Sleep(timeout * 2) + // write a complete message _, err = nc.Write([]byte{2, 3, 4}) noerr(t, err) }) p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - conn, err := net.Dial("tcp", addr.String()) - noerr(t, err) - return newLimitConn(conn, 10), nil - }) - }), + poolConfig{Address: address.Address(addr.String())}, ) defer p.close(context.Background()) err := p.ready() @@ -1298,36 +1368,65 @@ func TestBackgroundRead(t *testing.T) { assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) err = p.checkIn(conn) noerr(t, err) - wantErrs := []string{ - "error discarding 3 byte message: EOF", - } var bgErrs []error select { case bgErrs = <-errsCh: case <-time.After(3 * time.Second): assert.Fail(t, "did not receive expected error after waiting for 3 seconds") } - for i, err := range bgErrs { - if i < len(wantErrs) { - assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) - } - } + require.Len(t, bgErrs, 0, "expected no error from bgRead()") }) -} + t.Run("timeout reading full message, background read EOF", func(t *testing.T) { + errsCh := make(chan []error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errsCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) -type limitConn struct { - net.Conn - r io.Reader -} + const timeout = 10 * time.Millisecond -func newLimitConn(conn net.Conn, n int64) limitConn { - return limitConn{conn, io.LimitReader(conn, n)} -} + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + defer func() { + _ = nc.Close() + }() + + var err error + _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) + noerr(t, err) + time.Sleep(timeout * 2) + // write an incomplete message + _, err = nc.Write([]byte{2}) + noerr(t, err) + }) + + p := newPool( + poolConfig{Address: address.Address(addr.String())}, + ) + defer p.close(context.Background()) + err := p.ready() + noerr(t, err) -func (lc limitConn) Read(b []byte) (n int, err error) { - return lc.r.Read(b) + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) + err = p.checkIn(conn) + noerr(t, err) + var bgErrs []error + select { + case bgErrs = <-errsCh: + case <-time.After(3 * time.Second): + assert.Fail(t, "did not receive expected error after waiting for 3 seconds") + } + require.Len(t, bgErrs, 1, "expected 1 error from bgRead()") + assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF") + }) } func assertConnectionsClosed(t *testing.T, dialer *dialer, count int) {