@@ -85,7 +85,7 @@ type keyGenerator struct {
8585// - generate an RSA key lazily when it's requested, or
8686// - (default) immediately generate an RSA key as part of the initializer.
8787func newKeyGenerator (
88- k * rsa.PrivateKey , lazy bool , genFunc func () (* rsa.PrivateKey , error ),
88+ k * rsa.PrivateKey , lazy bool , genFunc func () (* rsa.PrivateKey , error ),
8989) (* keyGenerator , error ) {
9090 g := & keyGenerator {genFunc : genFunc }
9191 switch {
@@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
418418 tlsConn := tls .Client (conn , ci .TLSConfig ())
419419 err = tlsConn .HandshakeContext (ctx )
420420 if err != nil {
421+ // TLS handshake errors are fatal and require a refresh. Remove the instance
422+ // from the cache so that future calls to Dial() will block until the
423+ // certificate is refreshed successfully.
421424 d .logger .Debugf (ctx , "[%v] TLS handshake failed: %v" , cn .String (), err )
422- // refresh the instance info in case it caused the handshake failure
423- c .ForceRefresh ()
425+ d .removeCached (ctx , cn , c , err )
424426 _ = tlsConn .Close () // best effort close attempt
425427 return nil , errtype .NewDialError ("handshake failed" , cn .String (), err )
426428 }
@@ -435,7 +437,21 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
435437 iConn := newInstrumentedConn (tlsConn , func () {
436438 n := atomic .AddUint64 (c .openConnsCount , ^ uint64 (0 )) // c.openConnsCount = c.openConnsCount - 1
437439 trace .RecordOpenConnections (context .Background (), int64 (n ), d .dialerID , cn .String ())
438- }, d .dialerID , cn .String ())
440+ },
441+ func (err error ) {
442+ // ignore EOF
443+ if err == io .EOF {
444+ return
445+ }
446+ d .logger .Debugf (ctx , "[%v] IO Error on Read or Write: %v" , cn .String (), err )
447+ if d .isTLSError (err ) {
448+ // TLS handshake errors are fatal. Remove the instance from the cache
449+ // so that future calls to Dial() will block until the certificate
450+ // is refreshed successfully.
451+ d .removeCached (ctx , cn , c , err )
452+ }
453+ },
454+ d .dialerID , cn .String ())
439455
440456 // If this connection was opened using a Domain Name, then store it for later
441457 // in case it needs to be forcibly closed.
@@ -446,12 +462,18 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
446462 }
447463 return iConn , nil
448464}
465+ func (d * Dialer ) isTLSError (err error ) bool {
466+ if nErr , ok := err .(net.Error ); ok {
467+ return ! nErr .Timeout ()
468+ }
469+ return false
470+ }
449471
450472// removeCached stops all background refreshes and deletes the connection
451473// info cache from the map of caches.
452474func (d * Dialer ) removeCached (
453- ctx context.Context ,
454- i instance.ConnName , c connectionInfoCache , err error ,
475+ ctx context.Context ,
476+ i instance.ConnName , c connectionInfoCache , err error ,
455477) {
456478 d .logger .Debugf (
457479 ctx ,
@@ -461,16 +483,19 @@ func (d *Dialer) removeCached(
461483 )
462484 d .lock .Lock ()
463485 defer d .lock .Unlock ()
464- c .Close ()
465- delete (d .cache , createKey (i ))
486+ key := createKey (i )
487+ if cachedC , ok := d .cache [key ]; ok && cachedC == c {
488+ c .Close ()
489+ delete (d .cache , key )
490+ }
466491}
467492
468493// validClientCert checks that the ephemeral client certificate retrieved from
469494// the cache is unexpired. The time comparisons strip the monotonic clock value
470495// to ensure an accurate result, even after laptop sleep.
471496func validClientCert (
472- ctx context.Context , cn instance.ConnName ,
473- l debug.ContextLogger , expiration time.Time ,
497+ ctx context.Context , cn instance.ConnName ,
498+ l debug.ContextLogger , expiration time.Time ,
474499) bool {
475500 // Use UTC() to strip monotonic clock value to guard against inaccurate
476501 // comparisons, especially after laptop sleep.
@@ -535,10 +560,11 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
535560
536561// newInstrumentedConn initializes an instrumentedConn that on closing will
537562// decrement the number of open connects and record the result.
538- func newInstrumentedConn (conn net.Conn , closeFunc func (), dialerID , connName string ) * instrumentedConn {
563+ func newInstrumentedConn (conn net.Conn , closeFunc func (), errFunc func ( error ), dialerID , connName string ) * instrumentedConn {
539564 return & instrumentedConn {
540565 Conn : conn ,
541566 closeFunc : closeFunc ,
567+ errFunc : errFunc ,
542568 dialerID : dialerID ,
543569 connName : connName ,
544570 }
@@ -549,6 +575,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
549575type instrumentedConn struct {
550576 net.Conn
551577 closeFunc func ()
578+ errFunc func (error )
552579 mu sync.RWMutex
553580 closed bool
554581 dialerID string
@@ -561,6 +588,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
561588 bytesRead , err := i .Conn .Read (b )
562589 if err == nil {
563590 go trace .RecordBytesReceived (context .Background (), int64 (bytesRead ), i .connName , i .dialerID )
591+ } else {
592+ i .errFunc (err )
564593 }
565594 return bytesRead , err
566595}
@@ -571,6 +600,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
571600 bytesWritten , err := i .Conn .Write (b )
572601 if err == nil {
573602 go trace .RecordBytesSent (context .Background (), int64 (bytesWritten ), i .connName , i .dialerID )
603+ } else {
604+ i .errFunc (err )
574605 }
575606 return bytesWritten , err
576607}
@@ -634,7 +665,7 @@ func createKey(cn instance.ConnName) cacheKey {
634665// connection info Cache in a threadsafe way. It will create a new cache,
635666// modify the existing one, or leave it unchanged as needed.
636667func (d * Dialer ) connectionInfoCache (
637- ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
668+ ctx context.Context , cn instance.ConnName , useIAMAuthN * bool ,
638669) (* monitoredCache , error ) {
639670 k := createKey (cn )
640671
0 commit comments