Skip to content

Commit 34d063f

Browse files
author
Divjot Arora
authored
GODRIVER-1658 Fix connection updates to topology (#429)
1 parent 7d3355f commit 34d063f

File tree

7 files changed

+216
-91
lines changed

7 files changed

+216
-91
lines changed

mongo/integration/mtest/mongotest.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ type FailPointData struct {
6464
FailBeforeCommitExceptionCode int32 `bson:"failBeforeCommitExceptionCode,omitempty"`
6565
ErrorLabels *[]string `bson:"errorLabels,omitempty"`
6666
WriteConcernError *WriteConcernErrorData `bson:"writeConcernError,omitempty"`
67+
BlockConnection bool `bson:"blockConnection,omitempty"`
68+
BlockTimeMS int32 `bson:"blockTimeMS,omitempty"`
69+
AppName string `bson:"appName,omitempty"`
6770
}
6871

6972
// WriteConcernErrorData is a representation of the FailPoint.Data.WriteConcernError field.

mongo/integration/sdam_error_handling_test.go

Lines changed: 142 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,44 +24,154 @@ func TestSDAMErrorHandling(t *testing.T) {
2424
SetRetryWrites(false).
2525
SetPoolMonitor(poolMonitor).
2626
SetWriteConcern(mtest.MajorityWc)
27-
mtOpts := mtest.NewOptions().
28-
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
29-
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
30-
ClientOptions(clientOpts)
31-
32-
mt.RunOpts("network errors", mtOpts, func(mt *mtest.T) {
33-
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
34-
clearPoolChan()
35-
mt.SetFailPoint(mtest.FailPoint{
36-
ConfigureFailPoint: "failCommand",
37-
Mode: mtest.FailPointMode{
38-
Times: 1,
39-
},
40-
Data: mtest.FailPointData{
41-
FailCommands: []string{"insert"},
42-
CloseConnection: true,
43-
},
27+
baseMtOpts := func() *mtest.Options {
28+
mtOpts := mtest.NewOptions().
29+
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
30+
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
31+
ClientOptions(clientOpts)
32+
33+
if mt.TopologyKind() == mtest.Sharded {
34+
// Pin to a single mongos because the tests use failpoints.
35+
mtOpts.ClientType(mtest.Pinned)
36+
}
37+
return mtOpts
38+
}
39+
40+
// Set min server version of 4.4 because the during-handshake tests use failpoint features introduced in 4.4 like
41+
// blockConnection and appName.
42+
mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) {
43+
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
44+
mt.Run("pool cleared on network timeout", func(mt *mtest.T) {
45+
// Assert that the pool is cleared when a connection created by an application operation thread
46+
// encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test
47+
// connections created in the foreground for timeouts because connections created by the pool
48+
// maintenance routine can't be timed out using a context.
49+
50+
appName := "authNetworkTimeoutTest"
51+
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
52+
// speculative auth.
53+
mt.SetFailPoint(mtest.FailPoint{
54+
ConfigureFailPoint: "failCommand",
55+
Mode: mtest.FailPointMode{
56+
Times: 1,
57+
},
58+
Data: mtest.FailPointData{
59+
FailCommands: []string{"saslContinue"},
60+
BlockConnection: true,
61+
BlockTimeMS: 150,
62+
AppName: appName,
63+
},
64+
})
65+
66+
// Reset the client with the appName specified in the failpoint.
67+
clientOpts := options.Client().
68+
SetAppName(appName).
69+
SetRetryWrites(false).
70+
SetPoolMonitor(poolMonitor)
71+
mt.ResetClient(clientOpts)
72+
clearPoolChan()
73+
74+
// The saslContinue blocks for 150ms so run the InsertOne with a 100ms context to cause a network
75+
// timeout during auth and assert that the pool was cleared.
76+
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
77+
defer cancel()
78+
_, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}})
79+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
80+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
4481
})
82+
mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
83+
mt.Run("background", func(mt *mtest.T) {
84+
// Assert that the pool is cleared when a connection created by the background pool maintenance
85+
// routine encounters a non-timeout network error during handshaking.
86+
appName := "authNetworkErrorTestBackground"
87+
88+
mt.SetFailPoint(mtest.FailPoint{
89+
ConfigureFailPoint: "failCommand",
90+
Mode: mtest.FailPointMode{
91+
Times: 1,
92+
},
93+
Data: mtest.FailPointData{
94+
FailCommands: []string{"saslContinue"},
95+
CloseConnection: true,
96+
AppName: appName,
97+
},
98+
})
99+
100+
clientOpts := options.Client().
101+
SetAppName(appName).
102+
SetMinPoolSize(5).
103+
SetPoolMonitor(poolMonitor)
104+
mt.ResetClient(clientOpts)
105+
clearPoolChan()
106+
107+
time.Sleep(200 * time.Millisecond)
108+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
109+
})
110+
mt.Run("foreground", func(mt *mtest.T) {
111+
// Assert that the pool is cleared when a connection created by an application thread connection
112+
// checkout encounters a non-timeout network error during handshaking.
113+
appName := "authNetworkErrorTestForeground"
45114

46-
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
47-
assert.NotNil(mt, err, "expected InsertOne error, got nil")
48-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
115+
mt.SetFailPoint(mtest.FailPoint{
116+
ConfigureFailPoint: "failCommand",
117+
Mode: mtest.FailPointMode{
118+
Times: 1,
119+
},
120+
Data: mtest.FailPointData{
121+
FailCommands: []string{"saslContinue"},
122+
CloseConnection: true,
123+
AppName: appName,
124+
},
125+
})
126+
127+
clientOpts := options.Client().
128+
SetAppName(appName).
129+
SetPoolMonitor(poolMonitor)
130+
mt.ResetClient(clientOpts)
131+
clearPoolChan()
132+
133+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
134+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
135+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
136+
})
137+
})
49138
})
50-
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
51-
clearPoolChan()
139+
})
140+
mt.RunOpts("after handshake completes", baseMtOpts(), func(mt *mtest.T) {
141+
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
142+
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
143+
clearPoolChan()
144+
mt.SetFailPoint(mtest.FailPoint{
145+
ConfigureFailPoint: "failCommand",
146+
Mode: mtest.FailPointMode{
147+
Times: 1,
148+
},
149+
Data: mtest.FailPointData{
150+
FailCommands: []string{"insert"},
151+
CloseConnection: true,
152+
},
153+
})
52154

53-
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
54-
assert.Nil(mt, err, "InsertOne error: %v", err)
155+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
156+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
157+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
158+
})
159+
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
160+
clearPoolChan()
161+
162+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
163+
assert.Nil(mt, err, "InsertOne error: %v", err)
55164

56-
filter := bson.M{
57-
"$where": "function() { sleep(1000); return false; }",
58-
}
59-
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
60-
defer cancel()
61-
_, err = mt.Coll.Find(timeoutCtx, filter)
62-
assert.NotNil(mt, err, "expected Find error, got %v", err)
165+
filter := bson.M{
166+
"$where": "function() { sleep(1000); return false; }",
167+
}
168+
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
169+
defer cancel()
170+
_, err = mt.Coll.Find(timeoutCtx, filter)
171+
assert.NotNil(mt, err, "expected Find error, got %v", err)
63172

64-
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
173+
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
174+
})
65175
})
66176
})
67177
}

mongo/integration/sessions_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ func TestSessions(t *testing.T) {
7171
CreateClient(false)
7272
mt := mtest.New(t, mtOpts)
7373

74-
clusterTimeOpts := mtest.NewOptions().ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
74+
// Pin to a single mongos so heartbeats/handshakes to other mongoses
75+
// won't cause errors.
76+
clusterTimeOpts := mtest.NewOptions().
77+
ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
78+
ClientType(mtest.Pinned).
7579
CreateClient(false)
7680
mt.RunOpts("cluster time", clusterTimeOpts, func(mt *mtest.T) {
7781
// $clusterTime included in commands

x/mongo/driver/topology/connection.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
8888
return c, nil
8989
}
9090

91+
func (c *connection) processInitializationError(err error) {
92+
atomic.StoreInt32(&c.connected, disconnected)
93+
if c.nc != nil {
94+
_ = c.nc.Close()
95+
}
96+
97+
c.connectErr = ConnectionError{Wrapped: err, init: true}
98+
if c.config.errorHandlingCallback != nil {
99+
c.config.errorHandlingCallback(c.connectErr)
100+
}
101+
}
102+
91103
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
92104
// initialization handshakes.
93105
func (c *connection) connect(ctx context.Context) {
@@ -104,8 +116,7 @@ func (c *connection) connect(ctx context.Context) {
104116
var tempNc net.Conn
105117
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
106118
if err != nil {
107-
atomic.StoreInt32(&c.connected, disconnected)
108-
c.connectErr = ConnectionError{Wrapped: err, init: true}
119+
c.processInitializationError(err)
109120
return
110121
}
111122
c.nc = tempNc
@@ -121,11 +132,7 @@ func (c *connection) connect(ctx context.Context) {
121132
}
122133
tlsNc, err := configureTLS(ctx, c.nc, c.addr, tlsConfig, ocspOpts)
123134
if err != nil {
124-
if c.nc != nil {
125-
_ = c.nc.Close()
126-
}
127-
atomic.StoreInt32(&c.connected, disconnected)
128-
c.connectErr = ConnectionError{Wrapped: err, init: true}
135+
c.processInitializationError(err)
129136
return
130137
}
131138
c.nc = tlsNc
@@ -145,17 +152,10 @@ func (c *connection) connect(ctx context.Context) {
145152
err = handshaker.FinishHandshake(ctx, handshakeConn)
146153
}
147154
if err != nil {
148-
if c.nc != nil {
149-
_ = c.nc.Close()
150-
}
151-
atomic.StoreInt32(&c.connected, disconnected)
152-
c.connectErr = ConnectionError{Wrapped: err, init: true}
155+
c.processInitializationError(err)
153156
return
154157
}
155158

156-
if c.config.descCallback != nil {
157-
c.config.descCallback(c.desc)
158-
}
159159
if len(c.desc.Compression) > 0 {
160160
clientMethodLoop:
161161
for _, method := range c.config.compressors {

x/mongo/driver/topology/connection_options.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"go.mongodb.org/mongo-driver/event"
1010
"go.mongodb.org/mongo-driver/x/mongo/driver"
11-
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
1211
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
1312
)
1413

@@ -50,9 +49,9 @@ type connectionConfig struct {
5049
compressors []string
5150
zlibLevel *int
5251
zstdLevel *int
53-
descCallback func(description.Server)
5452
ocspCache ocsp.Cache
5553
disableOCSPEndpointCheck bool
54+
errorHandlingCallback func(error)
5655
}
5756

5857
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
@@ -76,16 +75,16 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
7675
return cfg, nil
7776
}
7877

79-
func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
80-
return append(opts, ConnectionOption(func(c *connectionConfig) error {
81-
c.descCallback = callback
82-
return nil
83-
}))
84-
}
85-
8678
// ConnectionOption is used to configure a connection.
8779
type ConnectionOption func(*connectionConfig) error
8880

81+
func withErrorHandlingCallback(fn func(error)) ConnectionOption {
82+
return func(c *connectionConfig) error {
83+
c.errorHandlingCallback = fn
84+
return nil
85+
}
86+
}
87+
8988
// WithCompressors sets the compressors that can be used for communication.
9089
func WithCompressors(fn func([]string) []string) ConnectionOption {
9190
return func(c *connectionConfig) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,32 +96,34 @@ func TestConnection(t *testing.T) {
9696
t.Errorf("errors do not match. got %v; want %v", got, want)
9797
}
9898
})
99-
t.Run("calls description callback", func(t *testing.T) {
100-
want := description.Server{Addr: address.Address("1.2.3.4:56789")}
101-
var got description.Server
99+
t.Run("calls error callback", func(t *testing.T) {
100+
handshakerError := errors.New("handshaker error")
101+
var got error
102+
102103
conn, err := newConnection(context.Background(), address.Address(""),
103-
withServerDescriptionCallback(func(desc description.Server) { got = desc },
104-
WithHandshaker(func(Handshaker) Handshaker {
105-
return &testHandshaker{
106-
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
107-
return want, nil
108-
},
109-
}
110-
}),
111-
WithDialer(func(Dialer) Dialer {
112-
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
113-
return &net.TCPConn{}, nil
114-
})
115-
}),
116-
)...,
104+
WithHandshaker(func(Handshaker) Handshaker {
105+
return &testHandshaker{
106+
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
107+
return description.Server{}, handshakerError
108+
},
109+
}
110+
}),
111+
WithDialer(func(Dialer) Dialer {
112+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
113+
return &net.TCPConn{}, nil
114+
})
115+
}),
116+
withErrorHandlingCallback(func(err error) {
117+
got = err
118+
}),
117119
)
118120
noerr(t, err)
119121
conn.connect(context.Background())
122+
123+
var want error = ConnectionError{Wrapped: handshakerError}
120124
err = conn.wait()
121-
noerr(t, err)
122-
if !cmp.Equal(got, want) {
123-
t.Errorf("Server descriptions do not match. got %v; want %v", got, want)
124-
}
125+
assert.NotNil(t, err, "expected connect error %v, got nil", want)
126+
assert.Equal(t, want, got, "expected error %v, got %v", want, got)
125127
})
126128
})
127129
t.Run("writeWireMessage", func(t *testing.T) {

0 commit comments

Comments
 (0)