Skip to content

Commit fb0dcbf

Browse files
committed
fix: Refresh client cert when it is rejected by the server.
1 parent fbba169 commit fb0dcbf

File tree

3 files changed

+71
-36
lines changed

3 files changed

+71
-36
lines changed

dialer.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
418418
tlsConn := tls.Client(conn, ci.TLSConfig())
419419
err = tlsConn.HandshakeContext(ctx)
420420
if err != nil {
421+
// TLS handshake errors are fatal and require a refresh. Remove the instance
422+
// from the cache so that future calls to Dial() will block until the
423+
// certificate is refreshed successfully.
421424
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
422-
// refresh the instance info in case it caused the handshake failure
423-
c.ForceRefresh()
425+
d.removeCached(ctx, cn, c, err)
424426
_ = tlsConn.Close() // best effort close attempt
425427
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
426428
}
@@ -435,7 +437,22 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
435437
iConn := newInstrumentedConn(tlsConn, func() {
436438
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
437439
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
438-
}, d.dialerID, cn.String())
440+
},
441+
func(err error) {
442+
// ignore EOF
443+
if err == io.EOF {
444+
return
445+
}
446+
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
447+
if d.isTLSError(err) {
448+
// TLS errors on an established connection are fatal. Remove the instance
449+
// from the cache so that future calls to Dial() will block until the
450+
// certificate is refreshed successfully.
451+
d.removeCached(ctx, cn, c, err)
452+
_ = tlsConn.Close() // best effort close attempt
453+
}
454+
},
455+
d.dialerID, cn.String())
439456

440457
// If this connection was opened using a Domain Name, then store it for later
441458
// in case it needs to be forcibly closed.
@@ -446,12 +463,19 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
446463
}
447464
return iConn, nil
448465
}
466+
// isTLSError reports whether err is a net.Error that is not a timeout and
// whose message contains the substring "tls". Every other error, including
// one that does not implement net.Error, yields false.
//
// NOTE(review): the check uses a direct type assertion rather than errors.As,
// so an error that only wraps a net.Error will not match — confirm that
// callers always pass the unwrapped error returned by the connection.
func (d *Dialer) isTLSError(err error) bool {
467+
if nErr, ok := err.(net.Error); ok {
468+
return !nErr.Timeout() && // it's a permanent net error
469+
strings.Contains(nErr.Error(), "tls") // it's a TLS-related error
470+
}
471+
return false
472+
}
449473

450474
// removeCached stops all background refreshes and deletes the connection
451475
// info cache from the map of caches if it is still the cached entry.
452476
func (d *Dialer) removeCached(
453477
ctx context.Context,
454-
i instance.ConnName, c connectionInfoCache, err error,
478+
i instance.ConnName, c *monitoredCache, err error,
455479
) {
456480
d.logger.Debugf(
457481
ctx,
@@ -461,8 +485,11 @@ func (d *Dialer) removeCached(
461485
)
462486
d.lock.Lock()
463487
defer d.lock.Unlock()
464-
c.Close()
465-
delete(d.cache, createKey(i))
488+
key := createKey(i)
489+
if cachedC, ok := d.cache[key]; ok && cachedC == c {
490+
delete(d.cache, key)
491+
}
492+
c.connectionInfoCache.Close()
466493
}
467494

468495
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -504,7 +531,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
504531
}
505532
ci, err := c.ConnectionInfo(ctx)
506533
if err != nil {
507-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
534+
d.removeCached(ctx, cn, c, err)
508535
return "", err
509536
}
510537
return ci.DBVersion, nil
@@ -528,17 +555,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
528555
}
529556
_, err = c.ConnectionInfo(ctx)
530557
if err != nil {
531-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
558+
d.removeCached(ctx, cn, c, err)
532559
}
533560
return err
534561
}
535562

536563
// newInstrumentedConn initializes an instrumentedConn that on closing will
537564
// decrement the number of open connections and record the result.
538-
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
565+
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
539566
return &instrumentedConn{
540567
Conn: conn,
541568
closeFunc: closeFunc,
569+
errFunc: errFunc,
542570
dialerID: dialerID,
543571
connName: connName,
544572
}
@@ -549,6 +577,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
549577
type instrumentedConn struct {
550578
net.Conn
551579
closeFunc func()
580+
errFunc func(error)
552581
mu sync.RWMutex
553582
closed bool
554583
dialerID string
@@ -561,6 +590,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
561590
bytesRead, err := i.Conn.Read(b)
562591
if err == nil {
563592
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
593+
} else {
594+
i.errFunc(err)
564595
}
565596
return bytesRead, err
566597
}
@@ -571,6 +602,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
571602
bytesWritten, err := i.Conn.Write(b)
572603
if err == nil {
573604
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
605+
} else {
606+
i.errFunc(err)
574607
}
575608
return bytesWritten, err
576609
}

internal/mock/cloudsql.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
7575
if f.signer != nil {
7676
return f.signer(f.Cert, f.Key)
7777
}
78+
7879
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
7980
// legacy server mode, return only the server cert
8081
return toPEMFormat(f.certs.serverCert)
8182
}
8283
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
84+
8385
}
8486

8587
// ClientCert creates an ephemeral client certificate signed with the Cloud SQL

internal/mock/sqladmin.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,38 @@ func (r *Request) matches(hR *http.Request) bool {
102102
//
103103
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
104104
func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
105-
var ips []*sqladmin.IpMapping
106-
for ipType, addr := range i.ipAddrs {
107-
if ipType == "PUBLIC" {
108-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
109-
continue
110-
}
111-
if ipType == "PRIVATE" {
112-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
113-
}
114-
}
115-
116-
certBytes, err := i.serverCACert()
117-
if err != nil {
118-
panic(err)
119-
}
120-
121-
db := &sqladmin.ConnectSettings{
122-
BackendType: i.backendType,
123-
DatabaseVersion: i.dbVersion,
124-
DnsName: i.DNSName,
125-
IpAddresses: ips,
126-
Region: i.region,
127-
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
128-
PscEnabled: i.pscEnabled,
129-
ServerCaMode: i.serverCAMode,
130-
}
131-
132105
r := &Request{
133106
reqMethod: http.MethodGet,
134107
reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
135108
reqCt: ct,
136109
handle: func(resp http.ResponseWriter, _ *http.Request) {
110+
var ips []*sqladmin.IpMapping
111+
for ipType, addr := range i.ipAddrs {
112+
if ipType == "PUBLIC" {
113+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
114+
continue
115+
}
116+
if ipType == "PRIVATE" {
117+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
118+
}
119+
}
120+
121+
certBytes, err := i.serverCACert()
122+
if err != nil {
123+
panic(err)
124+
}
125+
126+
db := &sqladmin.ConnectSettings{
127+
BackendType: i.backendType,
128+
DatabaseVersion: i.dbVersion,
129+
DnsName: i.DNSName,
130+
IpAddresses: ips,
131+
Region: i.region,
132+
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
133+
PscEnabled: i.pscEnabled,
134+
ServerCaMode: i.serverCAMode,
135+
}
136+
137137
b, err := db.MarshalJSON()
138138
if err != nil {
139139
http.Error(resp, err.Error(), http.StatusInternalServerError)

0 commit comments

Comments
 (0)