Skip to content

Commit 2a5f9a4

Browse files
author
Divjot Arora
committed
GODRIVER-1879 Apply connectTimeoutMS to TLS handshake (#594)
1 parent 2c5b75b commit 2a5f9a4

File tree

4 files changed

+197
-11
lines changed

4 files changed

+197
-11
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,28 @@ func (c *connection) connect(ctx context.Context) {
115115
}
116116
defer close(c.connectDone)
117117

118+
// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
119+
//
120+
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
121+
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
122+
// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
123+
//
124+
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
125+
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
126+
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
127+
// holding the lock longer than necessary.
118128
c.connectContextMutex.Lock()
119-
ctx, c.cancelConnectContext = context.WithCancel(ctx)
129+
var handshakeCtx context.Context
130+
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
120131
c.connectContextMutex.Unlock()
121132

133+
dialCtx := handshakeCtx
134+
var dialCancel context.CancelFunc
135+
if c.config.connectTimeout != 0 {
136+
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
137+
defer dialCancel()
138+
}
139+
122140
defer func() {
123141
var cancelFn context.CancelFunc
124142

@@ -137,7 +155,7 @@ func (c *connection) connect(ctx context.Context) {
137155
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
138156
var err error
139157
var tempNc net.Conn
140-
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
158+
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
141159
if err != nil {
142160
c.processInitializationError(err)
143161
return
@@ -153,7 +171,7 @@ func (c *connection) connect(ctx context.Context) {
153171
Cache: c.config.ocspCache,
154172
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
155173
}
156-
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
174+
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
157175
if err != nil {
158176
c.processInitializationError(err)
159177
return
@@ -179,13 +197,13 @@ func (c *connection) connect(ctx context.Context) {
179197
var handshakeInfo driver.HandshakeInformation
180198
handshakeStartTime := time.Now()
181199
handshakeConn := initConnection{c}
182-
handshakeInfo, err = handshaker.GetHandshakeInformation(ctx, c.addr, handshakeConn)
200+
handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
183201
if err == nil {
184202
// We only need to retain the Description field as the connection's description. The authentication-related
185203
// fields in handshakeInfo are tracked by the handshaker if necessary.
186204
c.desc = handshakeInfo.Description
187205
c.isMasterRTT = time.Since(handshakeStartTime)
188-
err = handshaker.FinishHandshake(ctx, handshakeConn)
206+
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
189207
}
190208

191209
// We have a failed handshake here

x/mongo/driver/topology/connection_options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
7070
}
7171

7272
if cfg.dialer == nil {
73-
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
73+
cfg.dialer = &net.Dialer{}
7474
}
7575

7676
return cfg, nil

x/mongo/driver/topology/connection_test.go

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func TestConnection(t *testing.T) {
219219
for _, tc := range testCases {
220220
t.Run(tc.name, func(t *testing.T) {
221221
var sentCfg *tls.Config
222-
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
222+
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
223223
sentCfg = cfg
224224
return tls.Client(nc, cfg)
225225
}
@@ -252,6 +252,143 @@ func TestConnection(t *testing.T) {
252252
}
253253
})
254254
})
255+
t.Run("connectTimeout is applied correctly", func(t *testing.T) {
256+
testCases := []struct {
257+
name string
258+
contextTimeout time.Duration
259+
connectTimeout time.Duration
260+
maxConnectTime time.Duration
261+
}{
262+
// The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
263+
// both of the tests declared below. Both tests also specify a 10ms max connect time to provide
264+
// a large buffer for lag and avoid test flakiness.
265+
266+
{"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 10 * time.Millisecond},
267+
{"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 10 * time.Millisecond},
268+
}
269+
270+
for _, tc := range testCases {
271+
t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
272+
// Ensure the initial connection dial can be timed out and the connection propagates the error
273+
// from the dialer in this case.
274+
275+
connOpts := []ConnectionOption{
276+
WithDialer(func(Dialer) Dialer {
277+
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
278+
<-ctx.Done()
279+
return nil, ctx.Err()
280+
})
281+
}),
282+
WithConnectTimeout(func(time.Duration) time.Duration {
283+
return tc.connectTimeout
284+
}),
285+
}
286+
conn, err := newConnection("", connOpts...)
287+
assert.Nil(t, err, "newConnection error: %v", err)
288+
289+
ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
290+
defer cancel()
291+
var connectErr error
292+
callback := func() {
293+
conn.connect(ctx)
294+
connectErr = conn.wait()
295+
}
296+
assert.Soon(t, callback, tc.maxConnectTime)
297+
298+
ce, ok := connectErr.(ConnectionError)
299+
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
300+
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
301+
context.DeadlineExceeded, ce.Unwrap())
302+
})
303+
t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
304+
// Ensure the TLS handshake can be timed out and the connection propagates the error from the
305+
// tlsConn in this case.
306+
307+
var hangingTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
308+
tlsConn := tls.Client(nc, cfg)
309+
return newHangingTLSConn(tlsConn, tc.maxConnectTime)
310+
}
311+
312+
connOpts := []ConnectionOption{
313+
WithConnectTimeout(func(time.Duration) time.Duration {
314+
return tc.connectTimeout
315+
}),
316+
WithDialer(func(Dialer) Dialer {
317+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
318+
return &net.TCPConn{}, nil
319+
})
320+
}),
321+
WithTLSConfig(func(*tls.Config) *tls.Config {
322+
return &tls.Config{}
323+
}),
324+
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
325+
return hangingTLSConnectionSource
326+
}),
327+
}
328+
conn, err := newConnection("", connOpts...)
329+
assert.Nil(t, err, "newConnection error: %v", err)
330+
331+
ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
332+
defer cancel()
333+
var connectErr error
334+
callback := func() {
335+
conn.connect(ctx)
336+
connectErr = conn.wait()
337+
}
338+
assert.Soon(t, callback, tc.maxConnectTime)
339+
340+
ce, ok := connectErr.(ConnectionError)
341+
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
342+
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
343+
context.DeadlineExceeded, ce.Unwrap())
344+
})
345+
t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
346+
// Ensure that no additional timeout is applied to the handshake after the connection has been
347+
// established.
348+
349+
var getInfoCtx, finishCtx context.Context
350+
handshaker := &testHandshaker{
351+
getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
352+
getInfoCtx = ctx
353+
return driver.HandshakeInformation{}, nil
354+
},
355+
finishHandshake: func(ctx context.Context, _ driver.Connection) error {
356+
finishCtx = ctx
357+
return nil
358+
},
359+
}
360+
361+
connOpts := []ConnectionOption{
362+
WithConnectTimeout(func(time.Duration) time.Duration {
363+
return tc.connectTimeout
364+
}),
365+
WithDialer(func(Dialer) Dialer {
366+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
367+
return &net.TCPConn{}, nil
368+
})
369+
}),
370+
WithHandshaker(func(Handshaker) Handshaker {
371+
return handshaker
372+
}),
373+
}
374+
conn, err := newConnection("", connOpts...)
375+
assert.Nil(t, err, "newConnection error: %v", err)
376+
377+
bgCtx := context.Background()
378+
conn.connect(bgCtx)
379+
err = conn.wait()
380+
assert.Nil(t, err, "connect error: %v", err)
381+
382+
assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
383+
t.Helper()
384+
dl, ok := ctx.Deadline()
385+
assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
386+
}
387+
assertNoContextTimeout(t, getInfoCtx)
388+
assertNoContextTimeout(t, finishCtx)
389+
})
390+
}
391+
})
255392
})
256393
t.Run("writeWireMessage", func(t *testing.T) {
257394
t.Run("closed connection", func(t *testing.T) {
@@ -993,3 +1130,24 @@ func (t *testCancellationListener) assertMethodsCalled(testingT *testing.T, numL
9931130
assert.Equal(testingT, numStopListening, t.numStopListening, "expected StopListening to be called %d times, got %d",
9941131
numListen, t.numListen)
9951132
}
1133+
1134+
// hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
1135+
// sleep for a fixed amount of time.
1136+
type hangingTLSConn struct {
1137+
*tls.Conn
1138+
sleepTime time.Duration
1139+
}
1140+
1141+
var _ tlsConn = (*hangingTLSConn)(nil)
1142+
1143+
func newHangingTLSConn(conn *tls.Conn, sleepTime time.Duration) *hangingTLSConn {
1144+
return &hangingTLSConn{
1145+
Conn: conn,
1146+
sleepTime: sleepTime,
1147+
}
1148+
}
1149+
1150+
func (h *hangingTLSConn) Handshake() error {
1151+
time.Sleep(h.sleepTime)
1152+
return h.Conn.Handshake()
1153+
}

x/mongo/driver/topology/tls_connection_source.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,26 @@ import (
1111
"net"
1212
)
1313

14+
type tlsConn interface {
15+
net.Conn
16+
Handshake() error
17+
ConnectionState() tls.ConnectionState
18+
}
19+
20+
var _ tlsConn = (*tls.Conn)(nil)
21+
1422
type tlsConnectionSource interface {
15-
Client(net.Conn, *tls.Config) *tls.Conn
23+
Client(net.Conn, *tls.Config) tlsConn
1624
}
1725

18-
type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
26+
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn
27+
28+
var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)
1929

20-
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
30+
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
2131
return t(nc, cfg)
2232
}
2333

24-
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
34+
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
2535
return tls.Client(nc, cfg)
2636
}

0 commit comments

Comments
 (0)