Skip to content

Commit c335d0c

Browse files
author
Divjot Arora
authored
GODRIVER-1581 Implement Unwrap for all wrapping error types (#394)
1 parent 7b9dbec commit c335d0c

File tree

11 files changed

+257
-7
lines changed

11 files changed

+257
-7
lines changed

mongo/errors.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ func replaceErrors(err error) error {
3737
return ErrClientDisconnected
3838
}
3939
if de, ok := err.(driver.Error); ok {
40-
return CommandError{Code: de.Code, Message: de.Message, Labels: de.Labels, Name: de.Name}
40+
return CommandError{
41+
Code: de.Code,
42+
Message: de.Message,
43+
Labels: de.Labels,
44+
Name: de.Name,
45+
Wrapped: de.Wrapped,
46+
}
4147
}
4248
if qe, ok := err.(driver.QueryFailureError); ok {
4349
// qe.Message is "command failure"
@@ -83,6 +89,11 @@ func (ekve EncryptionKeyVaultError) Error() string {
8389
return fmt.Sprintf("key vault communication error: %v", ekve.Wrapped)
8490
}
8591

92+
// Unwrap returns the underlying error.
93+
func (ekve EncryptionKeyVaultError) Unwrap() error {
94+
return ekve.Wrapped
95+
}
96+
8697
// MongocryptdError represents an error while communicating with mongocryptd during client-side encryption.
8798
type MongocryptdError struct {
8899
Wrapped error
@@ -93,12 +104,18 @@ func (e MongocryptdError) Error() string {
93104
return fmt.Sprintf("mongocryptd communication error: %v", e.Wrapped)
94105
}
95106

107+
// Unwrap returns the underlying error.
108+
func (e MongocryptdError) Unwrap() error {
109+
return e.Wrapped
110+
}
111+
96112
// CommandError represents a server error during execution of a command. This can be returned by any operation.
97113
type CommandError struct {
98114
Code int32
99115
Message string
100116
Labels []string // Categories to which the error belongs
101117
Name string // A human-readable name corresponding to the error code
118+
Wrapped error // The underlying error, if one exists.
102119
}
103120

104121
// Error implements the error interface.
@@ -109,6 +126,11 @@ func (e CommandError) Error() string {
109126
return e.Message
110127
}
111128

129+
// Unwrap returns the underlying error.
130+
func (e CommandError) Unwrap() error {
131+
return e.Wrapped
132+
}
133+
112134
// HasErrorLabel returns true if the error contains the specified label.
113135
func (e CommandError) HasErrorLabel(label string) bool {
114136
if e.Labels != nil {

mongo/integration/errors_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
// +build go1.13
8+
9+
package integration
10+
11+
import (
12+
"context"
13+
"errors"
14+
"io"
15+
"testing"
16+
17+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
18+
"go.mongodb.org/mongo-driver/mongo"
19+
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
20+
"go.mongodb.org/mongo-driver/mongo/options"
21+
)
22+
23+
func TestErrors(t *testing.T) {
24+
mt := mtest.New(t, noClientOpts)
25+
defer mt.Close()
26+
27+
mt.RunOpts("errors are wrapped", noClientOpts, func(mt *mtest.T) {
28+
mt.Run("network error during application operation", func(mt *mtest.T) {
29+
ctx, cancel := context.WithCancel(context.Background())
30+
cancel()
31+
32+
err := mt.Client.Ping(ctx, mtest.PrimaryRp)
33+
assert.True(mt, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
34+
})
35+
36+
authOpts := mtest.NewOptions().Auth(true).Topologies(mtest.ReplicaSet, mtest.Single).MinServerVersion("4.0")
37+
mt.RunOpts("network error during auth", authOpts, func(mt *mtest.T) {
38+
mt.SetFailPoint(mtest.FailPoint{
39+
ConfigureFailPoint: "failCommand",
40+
Mode: mtest.FailPointMode{
41+
Times: 1,
42+
},
43+
Data: mtest.FailPointData{
44+
// Set the fail point for saslContinue rather than saslStart because the driver will use speculative
45+
// auth on 4.4+ so there won't be an explicit saslStart command.
46+
FailCommands: []string{"saslContinue"},
47+
CloseConnection: true,
48+
},
49+
})
50+
51+
client, err := mongo.Connect(mtest.Background, options.Client().ApplyURI(mt.ConnString()))
52+
assert.Nil(mt, err, "Connect error: %v", err)
53+
defer client.Disconnect(mtest.Background)
54+
55+
// A connection getting closed should manifest as an io.EOF error.
56+
err = client.Ping(mtest.Background, mtest.PrimaryRp)
57+
assert.True(mt, errors.Is(err, io.EOF), "expected error %v, got %v", io.EOF, err)
58+
})
59+
})
60+
}

x/mongo/driver/auth/auth.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ func (e *Error) Inner() error {
201201
return e.inner
202202
}
203203

204+
// Unwrap returns the underlying error.
205+
func (e *Error) Unwrap() error {
206+
return e.inner
207+
}
208+
204209
// Message returns the message.
205210
func (e *Error) Message() string {
206211
return e.message

x/mongo/driver/errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ func (e Error) Error() string {
220220
return e.Message
221221
}
222222

223+
// Unwrap returns the underlying error.
224+
func (e Error) Unwrap() error {
225+
return e.Wrapped
226+
}
227+
223228
// HasErrorLabel returns true if the error contains the specified label.
224229
func (e Error) HasErrorLabel(label string) bool {
225230
if e.Labels != nil {

x/mongo/driver/ocsp/ocsp.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,9 @@ func contactResponders(ctx context.Context, cfg config) (*ResponseDetails, error
271271

272272
timeout := urlErr.Timeout()
273273
cancelled := urlErr.Err == context.Canceled // Timeout() does not return true for context.Cancelled.
274-
if userContextUsed && (timeout || cancelled) {
275-
// Handle the original context expiring or being cancelled.
274+
if cancelled || (userContextUsed && timeout) {
275+
// Handle the original context expiring or being cancelled. The url.Error type supports Unwrap, so
276+
// users can use errors.Is to check for context errors.
276277
return err
277278
}
278279
return nil // Ignore all other errors.

x/mongo/driver/ocsp/ocsp_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
// +build go1.13
8+
9+
package ocsp
10+
11+
import (
12+
"context"
13+
"crypto/x509"
14+
"errors"
15+
"testing"
16+
17+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
18+
)
19+
20+
func TestOCSP(t *testing.T) {
21+
t.Run("contactResponders", func(t *testing.T) {
22+
t.Run("cancelled context is propagated", func(t *testing.T) {
23+
ctx, cancel := context.WithCancel(context.Background())
24+
cancel()
25+
26+
serverCert := &x509.Certificate{
27+
OCSPServer: []string{"https://localhost:5000"},
28+
}
29+
cfg := config{
30+
serverCert: serverCert,
31+
issuer: &x509.Certificate{},
32+
cache: NewCache(),
33+
}
34+
35+
_, err := contactResponders(ctx, cfg)
36+
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
37+
})
38+
})
39+
}

x/mongo/driver/topology/connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config
541541
return nil, ocspErr
542542
}
543543
case <-ctx.Done():
544-
return nil, errors.New("server connection cancelled/timeout during TLS handshake")
544+
return nil, ctx.Err()
545545
}
546546
return client, nil
547547
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
// +build go1.13
8+
9+
package topology
10+
11+
import (
12+
"context"
13+
"errors"
14+
"net"
15+
"testing"
16+
17+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
18+
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
19+
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
20+
)
21+
22+
func TestConnectionErrors(t *testing.T) {
23+
t.Run("errors are wrapped", func(t *testing.T) {
24+
t.Run("dial error", func(t *testing.T) {
25+
dialError := errors.New("foo")
26+
27+
conn, err := newConnection(context.Background(), address.Address(""), WithDialer(func(Dialer) Dialer {
28+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, dialError })
29+
}))
30+
assert.Nil(t, err, "newConnection error: %v", err)
31+
32+
conn.connect(context.Background())
33+
err = conn.wait()
34+
assert.True(t, errors.Is(err, dialError), "expected error %v, got %v", dialError, err)
35+
})
36+
t.Run("handshake error", func(t *testing.T) {
37+
conn, err := newConnection(context.Background(), address.Address(""),
38+
WithHandshaker(func(Handshaker) Handshaker {
39+
return auth.Handshaker(nil, &auth.HandshakeOptions{})
40+
}),
41+
WithDialer(func(Dialer) Dialer {
42+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
43+
return &net.TCPConn{}, nil
44+
})
45+
}),
46+
)
47+
assert.Nil(t, err, "newConnection error: %v", err)
48+
defer conn.close()
49+
50+
ctx, cancel := context.WithCancel(context.Background())
51+
cancel()
52+
conn.connect(ctx)
53+
err = conn.wait()
54+
55+
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
56+
})
57+
t.Run("write error", func(t *testing.T) {
58+
ctx, cancel := context.WithCancel(context.Background())
59+
cancel()
60+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
61+
err := conn.writeWireMessage(ctx, []byte{})
62+
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
63+
})
64+
t.Run("read error", func(t *testing.T) {
65+
ctx, cancel := context.WithCancel(context.Background())
66+
cancel()
67+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
68+
_, err := conn.readWireMessage(ctx, []byte{})
69+
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
70+
})
71+
})
72+
}

x/mongo/driver/topology/errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,8 @@ func (e ConnectionError) Error() string {
2020
}
2121
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message)
2222
}
23+
24+
// Unwrap returns the underlying error.
25+
func (e ConnectionError) Unwrap() error {
26+
return e.Wrapped
27+
}

x/mongo/driver/topology/topology.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect
325325
if !doneOnce {
326326
// for the first pass, select a server from the current description.
327327
// this improves selection speed for up-to-date topology descriptions.
328-
suitable, selectErr = t.selectServerFromDescription(ctx, t.Description(), selectionState)
328+
suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState)
329329
doneOnce = true
330330
} else {
331331
// if the first pass didn't select a server, the previous description did not contain a suitable server, so
@@ -450,7 +450,7 @@ func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptio
450450
case current = <-subscriptionCh:
451451
}
452452

453-
suitable, err := t.selectServerFromDescription(ctx, current, selectionState)
453+
suitable, err := t.selectServerFromDescription(current, selectionState)
454454
if err != nil {
455455
return nil, err
456456
}
@@ -463,7 +463,7 @@ func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptio
463463
}
464464

465465
// selectServerFromDescription process the given topology description and returns a slice of suitable servers.
466-
func (t *Topology) selectServerFromDescription(ctx context.Context, desc description.Topology,
466+
func (t *Topology) selectServerFromDescription(desc description.Topology,
467467
selectionState serverSelectionState) ([]description.Server, error) {
468468

469469
// Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because

0 commit comments

Comments
 (0)