Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (c *connection) connect(ctx context.Context) (err error) {
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
tempNc, err := c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
if err != nil {
return ConnectionError{Wrapped: err, init: true}
return ConnectionError{Wrapped: err, init: true, message: fmt.Sprintf("failed to connect to %s", c.addr)}
}
c.nc = tempNc

Expand All @@ -229,7 +229,7 @@ func (c *connection) connect(ctx context.Context) (err error) {
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)

if err != nil {
return ConnectionError{Wrapped: err, init: true}
return ConnectionError{Wrapped: err, init: true, message: fmt.Sprintf("failed to configure TLS for %s", c.addr)}
}
c.nc = tlsNc
}
Expand Down Expand Up @@ -341,7 +341,10 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
return originalError
}
if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
return fmt.Errorf("%w: %s", context.DeadlineExceeded, originalError.Error())
return fmt.Errorf("%w: %s: %s",
context.DeadlineExceeded,
"client timed out waiting for server response",
originalError.Error())
}

return originalError
Expand Down Expand Up @@ -413,9 +416,6 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
c.close()
}
message := errMsg
if errors.Is(err, io.EOF) {
message = "socket was unexpectedly closed"
}
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ func TestConnection(t *testing.T) {
t.Run("connect", func(t *testing.T) {
t.Run("dialer error", func(t *testing.T) {
err := errors.New("dialer error")
var want error = ConnectionError{Wrapped: err, init: true}
conn := newConnection(address.Address(""), WithDialer(func(Dialer) Dialer {
var want error = ConnectionError{Wrapped: err, init: true, message: "failed to connect to testaddr:27017"}
conn := newConnection(address.Address("testaddr"), WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, err })
}))
got := conn.connect(context.Background())
Expand Down
33 changes: 23 additions & 10 deletions x/mongo/driver/topology/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"time"

"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)

var _ error = ConnectionError{}

// ConnectionError represents a connection error.
type ConnectionError struct {
ConnectionID string
Expand All @@ -28,21 +34,28 @@ type ConnectionError struct {

// Error implements the error interface.
func (e ConnectionError) Error() string {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest adding a compile check for ConnectionError, I just noticed we don't do this:

var _ error = ConnectionError{}

message := e.message
var messages []string
if e.init {
fullMsg := "error occurred during connection handshake"
if message != "" {
fullMsg = fmt.Sprintf("%s: %s", fullMsg, message)
}
message = fullMsg
messages = append(messages, "error occurred during connection handshake")
}
if e.Wrapped != nil && message != "" {
return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, message, e.Wrapped.Error())
if e.message != "" {
messages = append(messages, e.message)
}
if e.Wrapped != nil {
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.Wrapped.Error())
if errors.Is(e.Wrapped, io.EOF) {
messages = append(messages, "connection closed unexpectedly by the other side")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should create integration tests that verify all three branches work as expected. For the io.EOF you could do something like this:

	addr := bootstrapConnections(t, 1, func(nc net.Conn) {
		_ = nc.Close() // Close the connection "server-side" / "other-side"
	})

	p := newPool(
		poolConfig{Address: address.Address(addr.String())},
	)
	defer p.close(context.Background())
	err := p.ready()
	require.NoError(t, err)

	conn, err := p.checkOut(context.Background())
	fmt.Println(err)

	_, err = conn.readWireMessage(context.Background())
	fmt.Println(err) // expect io.EOF

The timeout branches would just require holding a server response > than the timeout. I know we already do this with TestBackgroundRead but that is a secondary check of those tests. We should verify this in a test designed specifically to do som.

}
if errors.Is(e.Wrapped, os.ErrDeadlineExceeded) {
messages = append(messages, "client timed out waiting for server response")
} else if err, ok := e.Wrapped.(net.Error); ok && err.Timeout() {
messages = append(messages, "client timed out waiting for server response")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the transformNetworkError function will wrap the error with context.DeadlineExceeded. Should we include that case here for safety?

messages = append(messages, e.Wrapped.Error())
}
if len(messages) > 0 {
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, strings.Join(messages, ": "))
}
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, message)
return fmt.Sprintf("connection(%s)", e.ConnectionID)
}

// Unwrap returns the underlying error.
Expand Down
15 changes: 8 additions & 7 deletions x/mongo/driver/topology/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ func TestPool_checkOut(t *testing.T) {

dialErr := errors.New("create new connection error")
p := newPool(poolConfig{
Address: "testaddr",
ConnectTimeout: defaultConnectionTimeout,
}, WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
Expand All @@ -481,7 +482,7 @@ func TestPool_checkOut(t *testing.T) {
require.NoError(t, err)

_, err = p.checkOut(context.Background())
var want error = ConnectionError{Wrapped: dialErr, init: true}
var want error = ConnectionError{Wrapped: dialErr, init: true, message: "failed to connect to testaddr:27017"}
assert.Equalf(t, want, err, "should return error from calling checkOut()")
// If a connection initialization error occurs during checkOut, removing and closing the
// failed connection both happen asynchronously with the checkOut. Wait for up to 2s for
Expand Down Expand Up @@ -1278,7 +1279,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
Expand Down Expand Up @@ -1318,7 +1319,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
Expand Down Expand Up @@ -1365,7 +1366,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
Expand Down Expand Up @@ -1417,7 +1418,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
Expand Down Expand Up @@ -1471,7 +1472,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
Expand Down Expand Up @@ -1521,7 +1522,7 @@ func TestBackgroundRead(t *testing.T) {
defer cancel()
_, err = conn.readWireMessage(ctx)
regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: client timed out waiting for server response: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)
err = p.checkIn(conn)
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ func TestServer(t *testing.T) {
}

authErr := ConnectionError{Wrapped: &auth.Error{}, init: true}
netErr := ConnectionError{Wrapped: &net.AddrError{}, init: true}
netErr := ConnectionError{Wrapped: &net.AddrError{}, init: true, message: "failed to connect to localhost:27017"}
for _, tt := range serverTestTable {
t.Run(tt.name, func(t *testing.T) {
var returnConnectionError bool
Expand Down
Loading