@@ -219,7 +219,7 @@ func TestConnection(t *testing.T) {
219
219
for _ , tc := range testCases {
220
220
t .Run (tc .name , func (t * testing.T ) {
221
221
var sentCfg * tls.Config
222
- var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) * tls. Conn {
222
+ var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
223
223
sentCfg = cfg
224
224
return tls .Client (nc , cfg )
225
225
}
@@ -252,6 +252,143 @@ func TestConnection(t *testing.T) {
252
252
}
253
253
})
254
254
})
255
+ t .Run ("connectTimeout is applied correctly" , func (t * testing.T ) {
256
+ testCases := []struct {
257
+ name string
258
+ contextTimeout time.Duration
259
+ connectTimeout time.Duration
260
+ maxConnectTime time.Duration
261
+ }{
262
+ // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
263
+ // both of the tests declared below. Both tests also specify a 10ms max connect time to provide
264
+ // a large buffer for lag and avoid test flakiness.
265
+
266
+ {"context timeout is lower" , 1 * time .Millisecond , 100 * time .Millisecond , 10 * time .Millisecond },
267
+ {"connect timeout is lower" , 100 * time .Millisecond , 1 * time .Millisecond , 10 * time .Millisecond },
268
+ }
269
+
270
+ for _ , tc := range testCases {
271
+ t .Run ("timeout applied to socket establishment: " + tc .name , func (t * testing.T ) {
272
+ // Ensure the initial connection dial can be timed out and the connection propagates the error
273
+ // from the dialer in this case.
274
+
275
+ connOpts := []ConnectionOption {
276
+ WithDialer (func (Dialer ) Dialer {
277
+ return DialerFunc (func (ctx context.Context , _ , _ string ) (net.Conn , error ) {
278
+ <- ctx .Done ()
279
+ return nil , ctx .Err ()
280
+ })
281
+ }),
282
+ WithConnectTimeout (func (time.Duration ) time.Duration {
283
+ return tc .connectTimeout
284
+ }),
285
+ }
286
+ conn , err := newConnection ("" , connOpts ... )
287
+ assert .Nil (t , err , "newConnection error: %v" , err )
288
+
289
+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
290
+ defer cancel ()
291
+ var connectErr error
292
+ callback := func () {
293
+ conn .connect (ctx )
294
+ connectErr = conn .wait ()
295
+ }
296
+ assert .Soon (t , callback , tc .maxConnectTime )
297
+
298
+ ce , ok := connectErr .(ConnectionError )
299
+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
300
+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
301
+ context .DeadlineExceeded , ce .Unwrap ())
302
+ })
303
+ t .Run ("timeout applied to TLS handshake: " + tc .name , func (t * testing.T ) {
304
+ // Ensure the TLS handshake can be timed out and the connection propagates the error from the
305
+ // tlsConn in this case.
306
+
307
+ var hangingTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
308
+ tlsConn := tls .Client (nc , cfg )
309
+ return newHangingTLSConn (tlsConn , tc .maxConnectTime )
310
+ }
311
+
312
+ connOpts := []ConnectionOption {
313
+ WithConnectTimeout (func (time.Duration ) time.Duration {
314
+ return tc .connectTimeout
315
+ }),
316
+ WithDialer (func (Dialer ) Dialer {
317
+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
318
+ return & net.TCPConn {}, nil
319
+ })
320
+ }),
321
+ WithTLSConfig (func (* tls.Config ) * tls.Config {
322
+ return & tls.Config {}
323
+ }),
324
+ withTLSConnectionSource (func (tlsConnectionSource ) tlsConnectionSource {
325
+ return hangingTLSConnectionSource
326
+ }),
327
+ }
328
+ conn , err := newConnection ("" , connOpts ... )
329
+ assert .Nil (t , err , "newConnection error: %v" , err )
330
+
331
+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
332
+ defer cancel ()
333
+ var connectErr error
334
+ callback := func () {
335
+ conn .connect (ctx )
336
+ connectErr = conn .wait ()
337
+ }
338
+ assert .Soon (t , callback , tc .maxConnectTime )
339
+
340
+ ce , ok := connectErr .(ConnectionError )
341
+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
342
+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
343
+ context .DeadlineExceeded , ce .Unwrap ())
344
+ })
345
+ t .Run ("timeout is not applied to handshaker: " + tc .name , func (t * testing.T ) {
346
+ // Ensure that no additional timeout is applied to the handshake after the connection has been
347
+ // established.
348
+
349
+ var getInfoCtx , finishCtx context.Context
350
+ handshaker := & testHandshaker {
351
+ getHandshakeInformation : func (ctx context.Context , _ address.Address , _ driver.Connection ) (driver.HandshakeInformation , error ) {
352
+ getInfoCtx = ctx
353
+ return driver.HandshakeInformation {}, nil
354
+ },
355
+ finishHandshake : func (ctx context.Context , _ driver.Connection ) error {
356
+ finishCtx = ctx
357
+ return nil
358
+ },
359
+ }
360
+
361
+ connOpts := []ConnectionOption {
362
+ WithConnectTimeout (func (time.Duration ) time.Duration {
363
+ return tc .connectTimeout
364
+ }),
365
+ WithDialer (func (Dialer ) Dialer {
366
+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
367
+ return & net.TCPConn {}, nil
368
+ })
369
+ }),
370
+ WithHandshaker (func (Handshaker ) Handshaker {
371
+ return handshaker
372
+ }),
373
+ }
374
+ conn , err := newConnection ("" , connOpts ... )
375
+ assert .Nil (t , err , "newConnection error: %v" , err )
376
+
377
+ bgCtx := context .Background ()
378
+ conn .connect (bgCtx )
379
+ err = conn .wait ()
380
+ assert .Nil (t , err , "connect error: %v" , err )
381
+
382
+ assertNoContextTimeout := func (t * testing.T , ctx context.Context ) {
383
+ t .Helper ()
384
+ dl , ok := ctx .Deadline ()
385
+ assert .False (t , ok , "expected context to have no deadline, but got deadline %v" , dl )
386
+ }
387
+ assertNoContextTimeout (t , getInfoCtx )
388
+ assertNoContextTimeout (t , finishCtx )
389
+ })
390
+ }
391
+ })
255
392
})
256
393
t .Run ("writeWireMessage" , func (t * testing.T ) {
257
394
t .Run ("closed connection" , func (t * testing.T ) {
@@ -993,3 +1130,24 @@ func (t *testCancellationListener) assertMethodsCalled(testingT *testing.T, numL
993
1130
assert .Equal (testingT , numStopListening , t .numStopListening , "expected StopListening to be called %d times, got %d" ,
994
1131
numListen , t .numListen )
995
1132
}
1133
+
1134
+ // hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
1135
+ // sleep for a fixed amount of time.
1136
+ type hangingTLSConn struct {
1137
+ * tls.Conn
1138
+ sleepTime time.Duration
1139
+ }
1140
+
1141
+ var _ tlsConn = (* hangingTLSConn )(nil )
1142
+
1143
+ func newHangingTLSConn (conn * tls.Conn , sleepTime time.Duration ) * hangingTLSConn {
1144
+ return & hangingTLSConn {
1145
+ Conn : conn ,
1146
+ sleepTime : sleepTime ,
1147
+ }
1148
+ }
1149
+
1150
+ func (h * hangingTLSConn ) Handshake () error {
1151
+ time .Sleep (h .sleepTime )
1152
+ return h .Conn .Handshake ()
1153
+ }
0 commit comments