Skip to content

Commit 0eb881e

Browse files
author
Divjot Arora
committed
GODRIVER-1658 Fix connection updates to topology (#429)
1 parent 8064395 commit 0eb881e

File tree

8 files changed

+231
-104
lines changed

8 files changed

+231
-104
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@ require (
2525
golang.org/x/text v0.3.2 // indirect
2626
golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d
2727
)
28+
29+
go 1.13

mongo/integration/mtest/mongotest.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ type FailPointData struct {
6666
Name string `bson:"codeName"`
6767
Errmsg string `bson:"errmsg"`
6868
} `bson:"writeConcernError,omitempty"`
69+
BlockConnection bool `bson:"blockConnection,omitempty"`
70+
BlockTimeMS int32 `bson:"blockTimeMS,omitempty"`
71+
AppName string `bson:"appName,omitempty"`
6972
}
7073

7174
// T is a wrapper around testing.T.

mongo/integration/sdam_error_handling_test.go

Lines changed: 142 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,44 +24,154 @@ func TestSDAMErrorHandling(t *testing.T) {
2424
SetRetryWrites(false).
2525
SetPoolMonitor(poolMonitor).
2626
SetWriteConcern(mtest.MajorityWc)
27-
mtOpts := mtest.NewOptions().
28-
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
29-
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
30-
ClientOptions(clientOpts)
31-
32-
mt.RunOpts("network errors", mtOpts, func(mt *mtest.T) {
33-
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
34-
clearPoolChan()
35-
mt.SetFailPoint(mtest.FailPoint{
36-
ConfigureFailPoint: "failCommand",
37-
Mode: mtest.FailPointMode{
38-
Times: 1,
39-
},
40-
Data: mtest.FailPointData{
41-
FailCommands: []string{"insert"},
42-
CloseConnection: true,
43-
},
27+
baseMtOpts := func() *mtest.Options {
28+
mtOpts := mtest.NewOptions().
29+
Topologies(mtest.ReplicaSet). // Don't run on sharded clusters to avoid complexity of sharded failpoints.
30+
MinServerVersion("4.0"). // 4.0+ is required to use failpoints on replica sets.
31+
ClientOptions(clientOpts)
32+
33+
if mt.TopologyKind() == mtest.Sharded {
34+
// Pin to a single mongos because the tests use failpoints.
35+
mtOpts.ClientType(mtest.Pinned)
36+
}
37+
return mtOpts
38+
}
39+
40+
// Set min server version of 4.4 because the during-handshake tests use failpoint features introduced in 4.4 like
41+
// blockConnection and appName.
42+
mt.RunOpts("before handshake completes", baseMtOpts().Auth(true).MinServerVersion("4.4"), func(mt *mtest.T) {
43+
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
44+
mt.Run("pool cleared on network timeout", func(mt *mtest.T) {
45+
// Assert that the pool is cleared when a connection created by an application operation thread
46+
// encounters a network timeout during handshaking. Unlike the non-timeout test below, we only test
47+
// connections created in the foreground for timeouts because connections created by the pool
48+
// maintenance routine can't be timed out using a context.
49+
50+
appName := "authNetworkTimeoutTest"
51+
// Set failpoint on saslContinue instead of saslStart because saslStart isn't done when using
52+
// speculative auth.
53+
mt.SetFailPoint(mtest.FailPoint{
54+
ConfigureFailPoint: "failCommand",
55+
Mode: mtest.FailPointMode{
56+
Times: 1,
57+
},
58+
Data: mtest.FailPointData{
59+
FailCommands: []string{"saslContinue"},
60+
BlockConnection: true,
61+
BlockTimeMS: 150,
62+
AppName: appName,
63+
},
64+
})
65+
66+
// Reset the client with the appName specified in the failpoint.
67+
clientOpts := options.Client().
68+
SetAppName(appName).
69+
SetRetryWrites(false).
70+
SetPoolMonitor(poolMonitor)
71+
mt.ResetClient(clientOpts)
72+
clearPoolChan()
73+
74+
// The saslContinue blocks for 150ms so run the InsertOne with a 100ms context to cause a network
75+
// timeout during auth and assert that the pool was cleared.
76+
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
77+
defer cancel()
78+
_, err := mt.Coll.InsertOne(timeoutCtx, bson.D{{"test", 1}})
79+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
80+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
4481
})
82+
mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) {
83+
mt.Run("background", func(mt *mtest.T) {
84+
// Assert that the pool is cleared when a connection created by the background pool maintenance
85+
// routine encounters a non-timeout network error during handshaking.
86+
appName := "authNetworkErrorTestBackground"
87+
88+
mt.SetFailPoint(mtest.FailPoint{
89+
ConfigureFailPoint: "failCommand",
90+
Mode: mtest.FailPointMode{
91+
Times: 1,
92+
},
93+
Data: mtest.FailPointData{
94+
FailCommands: []string{"saslContinue"},
95+
CloseConnection: true,
96+
AppName: appName,
97+
},
98+
})
99+
100+
clientOpts := options.Client().
101+
SetAppName(appName).
102+
SetMinPoolSize(5).
103+
SetPoolMonitor(poolMonitor)
104+
mt.ResetClient(clientOpts)
105+
clearPoolChan()
106+
107+
time.Sleep(200 * time.Millisecond)
108+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
109+
})
110+
mt.Run("foreground", func(mt *mtest.T) {
111+
// Assert that the pool is cleared when a connection created by an application thread connection
112+
// checkout encounters a non-timeout network error during handshaking.
113+
appName := "authNetworkErrorTestForeground"
45114

46-
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
47-
assert.NotNil(mt, err, "expected InsertOne error, got nil")
48-
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
115+
mt.SetFailPoint(mtest.FailPoint{
116+
ConfigureFailPoint: "failCommand",
117+
Mode: mtest.FailPointMode{
118+
Times: 1,
119+
},
120+
Data: mtest.FailPointData{
121+
FailCommands: []string{"saslContinue"},
122+
CloseConnection: true,
123+
AppName: appName,
124+
},
125+
})
126+
127+
clientOpts := options.Client().
128+
SetAppName(appName).
129+
SetPoolMonitor(poolMonitor)
130+
mt.ResetClient(clientOpts)
131+
clearPoolChan()
132+
133+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
134+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
135+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
136+
})
137+
})
49138
})
50-
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
51-
clearPoolChan()
139+
})
140+
mt.RunOpts("after handshake completes", baseMtOpts(), func(mt *mtest.T) {
141+
mt.RunOpts("network errors", noClientOpts, func(mt *mtest.T) {
142+
mt.Run("pool cleared on non-timeout network error", func(mt *mtest.T) {
143+
clearPoolChan()
144+
mt.SetFailPoint(mtest.FailPoint{
145+
ConfigureFailPoint: "failCommand",
146+
Mode: mtest.FailPointMode{
147+
Times: 1,
148+
},
149+
Data: mtest.FailPointData{
150+
FailCommands: []string{"insert"},
151+
CloseConnection: true,
152+
},
153+
})
52154

53-
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
54-
assert.Nil(mt, err, "InsertOne error: %v", err)
155+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"test", 1}})
156+
assert.NotNil(mt, err, "expected InsertOne error, got nil")
157+
assert.True(mt, isPoolCleared(), "expected pool to be cleared but was not")
158+
})
159+
mt.Run("pool not cleared on timeout network error", func(mt *mtest.T) {
160+
clearPoolChan()
161+
162+
_, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}})
163+
assert.Nil(mt, err, "InsertOne error: %v", err)
55164

56-
filter := bson.M{
57-
"$where": "function() { sleep(1000); return false; }",
58-
}
59-
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
60-
defer cancel()
61-
_, err = mt.Coll.Find(timeoutCtx, filter)
62-
assert.NotNil(mt, err, "expected Find error, got %v", err)
165+
filter := bson.M{
166+
"$where": "function() { sleep(1000); return false; }",
167+
}
168+
timeoutCtx, cancel := context.WithTimeout(mtest.Background, 100*time.Millisecond)
169+
defer cancel()
170+
_, err = mt.Coll.Find(timeoutCtx, filter)
171+
assert.NotNil(mt, err, "expected Find error, got %v", err)
63172

64-
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
173+
assert.False(mt, isPoolCleared(), "expected pool to not be cleared but was")
174+
})
65175
})
66176
})
67177
}

mongo/integration/sessions_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ func TestSessions(t *testing.T) {
7272
CreateClient(false)
7373
mt := mtest.New(t, mtOpts)
7474

75-
clusterTimeOpts := mtest.NewOptions().ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
75+
// Pin to a single mongos so heartbeats/handshakes to other mongoses
76+
// won't cause errors.
77+
clusterTimeOpts := mtest.NewOptions().
78+
ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
79+
ClientType(mtest.Pinned).
7680
CreateClient(false)
7781
mt.RunOpts("cluster time", clusterTimeOpts, func(mt *mtest.T) {
7882
// $clusterTime included in commands

x/mongo/driver/topology/connection.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ func newConnection(ctx context.Context, addr address.Address, opts ...Connection
8585
return c, nil
8686
}
8787

88+
// processInitializationError records a failure that occurred while the
// connection was being established. It marks the connection as disconnected,
// closes the underlying net.Conn if one has already been opened, stores err in
// c.connectErr wrapped in a ConnectionError with init set, and then passes that
// wrapped error to the configured errorHandlingCallback, if there is one.
func (c *connection) processInitializationError(err error) {
89+
atomic.StoreInt32(&c.connected, disconnected)
90+
if c.nc != nil {
91+
// The close error is deliberately discarded; err is the failure that gets reported.
_ = c.nc.Close()
92+
}
93+
94+
c.connectErr = ConnectionError{Wrapped: err, init: true}
95+
// The callback is optional and receives the wrapped ConnectionError, not the raw err.
if c.config.errorHandlingCallback != nil {
96+
c.config.errorHandlingCallback(c.connectErr)
97+
}
98+
}
99+
88100
// connect handles the I/O for a connection. It will dial, configure TLS, and perform
89101
// initialization handshakes.
90102
func (c *connection) connect(ctx context.Context) {
@@ -101,8 +113,7 @@ func (c *connection) connect(ctx context.Context) {
101113
var tempNc net.Conn
102114
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
103115
if err != nil {
104-
atomic.StoreInt32(&c.connected, disconnected)
105-
c.connectErr = ConnectionError{Wrapped: err, init: true}
116+
c.processInitializationError(err)
106117
return
107118
}
108119
c.nc = tempNc
@@ -114,11 +125,7 @@ func (c *connection) connect(ctx context.Context) {
114125
// error cases.
115126
tlsNc, err := configureTLS(ctx, c.nc, c.addr, tlsConfig)
116127
if err != nil {
117-
if c.nc != nil {
118-
_ = c.nc.Close()
119-
}
120-
atomic.StoreInt32(&c.connected, disconnected)
121-
c.connectErr = ConnectionError{Wrapped: err, init: true}
128+
c.processInitializationError(err)
122129
return
123130
}
124131
c.nc = tlsNc
@@ -138,17 +145,10 @@ func (c *connection) connect(ctx context.Context) {
138145
err = handshaker.FinishHandshake(ctx, handshakeConn)
139146
}
140147
if err != nil {
141-
if c.nc != nil {
142-
_ = c.nc.Close()
143-
}
144-
atomic.StoreInt32(&c.connected, disconnected)
145-
c.connectErr = ConnectionError{Wrapped: err, init: true}
148+
c.processInitializationError(err)
146149
return
147150
}
148151

149-
if c.config.descCallback != nil {
150-
c.config.descCallback(c.desc)
151-
}
152152
if len(c.desc.Compression) > 0 {
153153
clientMethodLoop:
154154
for _, method := range c.config.compressors {

x/mongo/driver/topology/connection_options.go

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"go.mongodb.org/mongo-driver/event"
1010
"go.mongodb.org/mongo-driver/x/mongo/driver"
11-
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
1211
)
1312

1413
// Dialer is used to make network connections.
@@ -36,20 +35,20 @@ var DefaultDialer Dialer = &net.Dialer{}
3635
type Handshaker = driver.Handshaker
3736

3837
type connectionConfig struct {
39-
appName string
40-
connectTimeout time.Duration
41-
dialer Dialer
42-
handshaker Handshaker
43-
idleTimeout time.Duration
44-
lifeTimeout time.Duration
45-
cmdMonitor *event.CommandMonitor
46-
readTimeout time.Duration
47-
writeTimeout time.Duration
48-
tlsConfig *tls.Config
49-
compressors []string
50-
zlibLevel *int
51-
zstdLevel *int
52-
descCallback func(description.Server)
38+
appName string
39+
connectTimeout time.Duration
40+
dialer Dialer
41+
handshaker Handshaker
42+
idleTimeout time.Duration
43+
lifeTimeout time.Duration
44+
cmdMonitor *event.CommandMonitor
45+
readTimeout time.Duration
46+
writeTimeout time.Duration
47+
tlsConfig *tls.Config
48+
compressors []string
49+
zlibLevel *int
50+
zstdLevel *int
51+
errorHandlingCallback func(error)
5352
}
5453

5554
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
@@ -73,16 +72,16 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
7372
return cfg, nil
7473
}
7574

76-
func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
77-
return append(opts, ConnectionOption(func(c *connectionConfig) error {
78-
c.descCallback = callback
79-
return nil
80-
}))
81-
}
82-
8375
// ConnectionOption is used to configure a connection.
8476
type ConnectionOption func(*connectionConfig) error
8577

78+
// withErrorHandlingCallback returns a ConnectionOption that installs fn as the
// connection's errorHandlingCallback. Applying the option always succeeds.
func withErrorHandlingCallback(fn func(error)) ConnectionOption {
79+
return func(c *connectionConfig) error {
80+
c.errorHandlingCallback = fn
81+
return nil
82+
}
83+
}
84+
8685
// WithCompressors sets the compressors that can be used for communication.
8786
func WithCompressors(fn func([]string) []string) ConnectionOption {
8887
return func(c *connectionConfig) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,32 +96,34 @@ func TestConnection(t *testing.T) {
9696
t.Errorf("errors do not match. got %v; want %v", got, want)
9797
}
9898
})
99-
t.Run("calls description callback", func(t *testing.T) {
100-
want := description.Server{Addr: address.Address("1.2.3.4:56789")}
101-
var got description.Server
99+
t.Run("calls error callback", func(t *testing.T) {
100+
handshakerError := errors.New("handshaker error")
101+
var got error
102+
102103
conn, err := newConnection(context.Background(), address.Address(""),
103-
withServerDescriptionCallback(func(desc description.Server) { got = desc },
104-
WithHandshaker(func(Handshaker) Handshaker {
105-
return &testHandshaker{
106-
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
107-
return want, nil
108-
},
109-
}
110-
}),
111-
WithDialer(func(Dialer) Dialer {
112-
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
113-
return &net.TCPConn{}, nil
114-
})
115-
}),
116-
)...,
104+
WithHandshaker(func(Handshaker) Handshaker {
105+
return &testHandshaker{
106+
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
107+
return description.Server{}, handshakerError
108+
},
109+
}
110+
}),
111+
WithDialer(func(Dialer) Dialer {
112+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
113+
return &net.TCPConn{}, nil
114+
})
115+
}),
116+
withErrorHandlingCallback(func(err error) {
117+
got = err
118+
}),
117119
)
118120
noerr(t, err)
119121
conn.connect(context.Background())
122+
123+
// The callback receives the error built by processInitializationError, which
// sets init on the ConnectionError, so the expected value must set it as well.
var want error = ConnectionError{Wrapped: handshakerError, init: true}
120124
err = conn.wait()
121-
noerr(t, err)
122-
if !cmp.Equal(got, want) {
123-
t.Errorf("Server descriptions do not match. got %v; want %v", got, want)
124-
}
125+
assert.NotNil(t, err, "expected connect error %v, got nil", want)
126+
assert.Equal(t, want, got, "expected error %v, got %v", want, got)
125127
})
126128
})
127129
t.Run("writeWireMessage", func(t *testing.T) {

0 commit comments

Comments
 (0)