@@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
418418 tlsConn := tls .Client (conn , ci .TLSConfig ())
419419 err = tlsConn .HandshakeContext (ctx )
420420 if err != nil {
421+ // TLS handshake errors are fatal and require a refresh. Remove the instance
422+ // from the cache so that future calls to Dial() will block until the
423+ // certificate is refreshed successfully.
421424 d .logger .Debugf (ctx , "[%v] TLS handshake failed: %v" , cn .String (), err )
422- // refresh the instance info in case it caused the handshake failure
423- c .ForceRefresh ()
425+ d .removeCached (ctx , cn , c , err )
424426 _ = tlsConn .Close () // best effort close attempt
425427 return nil , errtype .NewDialError ("handshake failed" , cn .String (), err )
426428 }
@@ -435,7 +437,22 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
435437 iConn := newInstrumentedConn (tlsConn , func () {
436438 n := atomic .AddUint64 (c .openConnsCount , ^ uint64 (0 )) // c.openConnsCount = c.openConnsCount - 1
437439 trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
438- }, d .dialerID , cn .String ())
440+ },
441+ func (err error ) {
442+ // ignore EOF
443+ if err == io .EOF {
444+ return
445+ }
446+ d .logger .Debugf (ctx , "[%v] IO Error on Read or Write: %v" , cn .String (), err )
447+ if d .isTLSError (err ) {
448+ // TLS handshake errors are fatal. Remove the instance from the cache
449+ // so that future calls to Dial() will block until the certificate
450+ // is refreshed successfully.
451+ d .removeCached (ctx , cn , c , err )
452+ _ = tlsConn .Close () // best effort close attempt
453+ }
454+ },
455+ d .dialerID , cn .String ())
439456
440457 // If this connection was opened using a Domain Name, then store it for later
441458 // in case it needs to be forcibly closed.
@@ -446,12 +463,19 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
446463 }
447464 return iConn , nil
448465}
466+ func (d * Dialer ) isTLSError (err error ) bool {
467+ if nErr , ok := err .(net.Error ); ok {
468+ return ! nErr .Timeout () && // it's a permanent net error
469+ strings .Contains (nErr .Error (), "tls" ) // it's a TLS-related error
470+ }
471+ return false
472+ }
449473
450474// removeCached stops all background refreshes and deletes the connection
451475// info cache from the map of caches.
452476func (d * Dialer ) removeCached (
453477 ctx context.Context ,
454- i instance.ConnName , c connectionInfoCache , err error ,
478+ i instance.ConnName , c * monitoredCache , err error ,
455479) {
456480 d .logger .Debugf (
457481 ctx ,
@@ -461,8 +485,11 @@ func (d *Dialer) removeCached(
461485 )
462486 d .lock .Lock ()
463487 defer d .lock .Unlock ()
464- c .Close ()
465- delete (d .cache , createKey (i ))
488+ key := createKey (i )
489+ if cachedC , ok := d .cache [key ]; ok && cachedC == c {
490+ delete (d .cache , key )
491+ }
492+ c .connectionInfoCache .Close ()
466493}
467494
468495// validClientCert checks that the ephemeral client certificate retrieved from
@@ -504,7 +531,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
504531 }
505532 ci , err := c .ConnectionInfo (ctx )
506533 if err != nil {
507- d .removeCached (ctx , cn , c . connectionInfoCache , err )
534+ d .removeCached (ctx , cn , c , err )
508535 return "" , err
509536 }
510537 return ci .DBVersion , nil
@@ -528,17 +555,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
528555 }
529556 _ , err = c .ConnectionInfo (ctx )
530557 if err != nil {
531- d .removeCached (ctx , cn , c . connectionInfoCache , err )
558+ d .removeCached (ctx , cn , c , err )
532559 }
533560 return err
534561}
535562
536563// newInstrumentedConn initializes an instrumentedConn that on closing will
537564// decrement the number of open connects and record the result.
538- func newInstrumentedConn (conn net.Conn , closeFunc func (), dialerID , connName string ) * instrumentedConn {
565+ func newInstrumentedConn (conn net.Conn , closeFunc func (), errFunc func ( error ), dialerID , connName string ) * instrumentedConn {
539566 return & instrumentedConn {
540567 Conn : conn ,
541568 closeFunc : closeFunc ,
569+ errFunc : errFunc ,
542570 dialerID : dialerID ,
543571 connName : connName ,
544572 }
@@ -549,6 +577,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
549577type instrumentedConn struct {
550578 net.Conn
551579 closeFunc func ()
580+ errFunc func (error )
552581 mu sync.RWMutex
553582 closed bool
554583 dialerID string
@@ -561,6 +590,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
561590 bytesRead , err := i .Conn .Read (b )
562591 if err == nil {
563592 go trace .RecordBytesReceived (context .Background (), int64 (bytesRead ), i .connName , i .dialerID )
593+ } else {
594+ i .errFunc (err )
564595 }
565596 return bytesRead , err
566597}
@@ -571,6 +602,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
571602 bytesWritten , err := i .Conn .Write (b )
572603 if err == nil {
573604 go trace .RecordBytesSent (context .Background (), int64 (bytesWritten ), i .connName , i .dialerID )
605+ } else {
606+ i .errFunc (err )
574607 }
575608 return bytesWritten , err
576609}
0 commit comments