Skip to content

Commit 0ae60ad

Browse files
committed
fix: Refresh client cert when it is rejected by the server.
1 parent 3076568 commit 0ae60ad

File tree

4 files changed

+139
-46
lines changed

4 files changed

+139
-46
lines changed

dialer.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
418418
tlsConn := tls.Client(conn, ci.TLSConfig())
419419
err = tlsConn.HandshakeContext(ctx)
420420
if err != nil {
421+
// TLS handshake errors are fatal and require a refresh. Remove the instance
422+
// from the cache so that future calls to Dial() will block until the
423+
// certificate is refreshed successfully.
421424
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
422-
// refresh the instance info in case it caused the handshake failure
423-
c.ForceRefresh()
425+
d.removeCached(ctx, cn, c, err)
424426
_ = tlsConn.Close() // best effort close attempt
425427
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
426428
}
@@ -435,7 +437,22 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
435437
iConn := newInstrumentedConn(tlsConn, func() {
436438
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
437439
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
438-
}, d.dialerID, cn.String())
440+
},
441+
func(err error) {
442+
// ignore EOF
443+
if err == io.EOF {
444+
return
445+
}
446+
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
447+
if d.isTLSError(err) {
448+
// TLS handshake errors are fatal. Remove the instance from the cache
449+
// so that future calls to Dial() will block until the certificate
450+
// is refreshed successfully.
451+
d.removeCached(ctx, cn, c, err)
452+
_ = tlsConn.Close() // best effort close attempt
453+
}
454+
},
455+
d.dialerID, cn.String())
439456

440457
// If this connection was opened using a Domain Name, then store it for later
441458
// in case it needs to be forcibly closed.
@@ -446,12 +463,19 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
446463
}
447464
return iConn, nil
448465
}
466+
func (d *Dialer) isTLSError(err error) bool {
467+
if nErr, ok := err.(net.Error); ok {
468+
return !nErr.Timeout() && // it's a permanent net error
469+
strings.Contains(nErr.Error(), "tls") // it's a TLS-related error
470+
}
471+
return false
472+
}
449473

450474
// removeCached stops all background refreshes and deletes the connection
451475
// info cache from the map of caches.
452476
func (d *Dialer) removeCached(
453477
ctx context.Context,
454-
i instance.ConnName, c connectionInfoCache, err error,
478+
i instance.ConnName, c *monitoredCache, err error,
455479
) {
456480
d.logger.Debugf(
457481
ctx,
@@ -461,8 +485,11 @@ func (d *Dialer) removeCached(
461485
)
462486
d.lock.Lock()
463487
defer d.lock.Unlock()
464-
c.Close()
465-
delete(d.cache, createKey(i))
488+
key := createKey(i)
489+
if cachedC, ok := d.cache[key]; ok && cachedC == c {
490+
delete(d.cache, key)
491+
}
492+
c.connectionInfoCache.Close()
466493
}
467494

468495
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -504,7 +531,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
504531
}
505532
ci, err := c.ConnectionInfo(ctx)
506533
if err != nil {
507-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
534+
d.removeCached(ctx, cn, c, err)
508535
return "", err
509536
}
510537
return ci.DBVersion, nil
@@ -528,17 +555,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
528555
}
529556
_, err = c.ConnectionInfo(ctx)
530557
if err != nil {
531-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
558+
d.removeCached(ctx, cn, c, err)
532559
}
533560
return err
534561
}
535562

536563
// newInstrumentedConn initializes an instrumentedConn that on closing will
537564
// decrement the number of open connects and record the result.
538-
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
565+
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
539566
return &instrumentedConn{
540567
Conn: conn,
541568
closeFunc: closeFunc,
569+
errFunc: errFunc,
542570
dialerID: dialerID,
543571
connName: connName,
544572
}
@@ -549,6 +577,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
549577
type instrumentedConn struct {
550578
net.Conn
551579
closeFunc func()
580+
errFunc func(error)
552581
mu sync.RWMutex
553582
closed bool
554583
dialerID string
@@ -561,6 +590,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
561590
bytesRead, err := i.Conn.Read(b)
562591
if err == nil {
563592
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
593+
} else {
594+
i.errFunc(err)
564595
}
565596
return bytesRead, err
566597
}
@@ -571,6 +602,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
571602
bytesWritten, err := i.Conn.Write(b)
572603
if err == nil {
573604
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
605+
} else {
606+
i.errFunc(err)
574607
}
575608
return bytesWritten, err
576609
}

dialer_test.go

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
11751175
}
11761176
}
11771177

1178-
func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
1178+
func TestDialerRefreshesAfterRotateClientCA(t *testing.T) {
11791179
inst := mock.NewFakeCSQLInstanceWithSan(
11801180
"my-project", "my-region", "my-instance", []string{"db.example.com"},
11811181
mock.WithDNS("db.example.com"),
@@ -1210,14 +1210,13 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
12101210
cancel1()
12111211

12121212
mock.RotateClientCA(inst)
1213-
time.Sleep(2 * time.Second)
12141213

1215-
// Recreate the instance, which generates new server certificates
12161214
// Start the server with new certificates
12171215
cancel2 := mock.StartServerProxy(t, inst)
12181216
defer cancel2()
12191217

12201218
// Dial a second time. We expect no error on dial, but TLS error on read.
1219+
t.Log("Second attempt should fail...")
12211220
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
12221221
if err != nil {
12231222
t.Fatal("Should be no certificate error after, got ", err)
@@ -1230,8 +1229,66 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
12301229
} else {
12311230
t.Fatal("Want read error, got no error")
12321231
}
1232+
t.Log("Second attempt done")
1233+
1234+
// Dial again. This should complete after the refresh.
1235+
t.Log("Third attempt...")
1236+
testSuccessfulDial(
1237+
context.Background(), t, d,
1238+
"my-project:my-region:my-instance",
1239+
)
1240+
t.Log("Third attempt OK.")
1241+
}
1242+
1243+
func TestDialerRefreshesAfterRotateServerCA(t *testing.T) {
1244+
inst := mock.NewFakeCSQLInstanceWithSan(
1245+
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1246+
mock.WithDNS("db.example.com"),
1247+
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1248+
)
1249+
1250+
d := setupDialer(t, setupConfig{
1251+
skipServer: true,
1252+
testInstance: inst,
1253+
reqs: []*mock.Request{
1254+
mock.InstanceGetSuccess(inst, 2),
1255+
mock.CreateEphemeralSuccess(inst, 2),
1256+
},
1257+
dialerOptions: []Option{
1258+
WithTokenSource(mock.EmptyTokenSource{}),
1259+
WithDebugLogger(&dialerTestLogger{t: t}),
1260+
WithLazyRefresh(),
1261+
// Note: this succeeds with lazy refresh, but fails with lazy.
1262+
// because dialer.ForceRefresh does not block connections while the
1263+
// refresh is in progress.
1264+
},
1265+
})
1266+
cancel1 := mock.StartServerProxy(t, inst)
1267+
t.Log("First attempt...")
1268+
testSuccessfulDial(
1269+
context.Background(), t, d,
1270+
"my-project:my-region:my-instance",
1271+
)
1272+
t.Log("First attempt OK. Resetting client cert.")
1273+
1274+
// Close the server
1275+
cancel1()
1276+
1277+
mock.RotateCA(inst)
1278+
1279+
// Start the server with new certificates
1280+
cancel2 := mock.StartServerProxy(t, inst)
1281+
defer cancel2()
1282+
1283+
// Dial a second time. We expect no error on dial, but TLS error on read.
1284+
t.Log("Second attempt should fail...")
1285+
_, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1286+
if err != nil {
1287+
t.Log("Got error on dial as expected.", err)
1288+
} else {
1289+
t.Fatal("Want dial error, got no error")
1290+
}
12331291

1234-
time.Sleep(2 * time.Second)
12351292
// Dial again. This should occur after the refresh has completed.
12361293
t.Log("Third attempt...")
12371294
testSuccessfulDial(

internal/mock/cloudsql.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,14 @@ func (f FakeCSQLInstance) String() string {
7474
func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
7575
if f.signer != nil {
7676
return f.signer(f.Cert, f.Key)
77-
} else {
78-
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
79-
// legacy server mode, return only the server cert
80-
return toPEMFormat(f.certs.serverCert)
81-
}
82-
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
8377
}
78+
79+
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
80+
// legacy server mode, return only the server cert
81+
return toPEMFormat(f.certs.serverCert)
82+
}
83+
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
84+
8485
}
8586

8687
// ClientCert creates an ephemeral client certificate signed with the Cloud SQL
@@ -297,10 +298,12 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
297298
}
298299
}
299300

301+
// RotateCA rotates all CA certificates in the instance.
300302
func RotateCA(inst FakeCSQLInstance) {
301303
inst.certs.rotateCA()
302304
}
303305

306+
// RotateClientCA rotates all only the client CA certificates in the instance.
304307
func RotateClientCA(inst FakeCSQLInstance) {
305308
inst.certs.rotateClientCA()
306309
}

internal/mock/sqladmin.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,38 @@ func (r *Request) matches(hR *http.Request) bool {
102102
//
103103
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
104104
func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
105-
var ips []*sqladmin.IpMapping
106-
for ipType, addr := range i.ipAddrs {
107-
if ipType == "PUBLIC" {
108-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
109-
continue
110-
}
111-
if ipType == "PRIVATE" {
112-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
113-
}
114-
}
115-
116-
certBytes, err := i.serverCACert()
117-
if err != nil {
118-
panic(err)
119-
}
120-
121-
db := &sqladmin.ConnectSettings{
122-
BackendType: i.backendType,
123-
DatabaseVersion: i.dbVersion,
124-
DnsName: i.DNSName,
125-
IpAddresses: ips,
126-
Region: i.region,
127-
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
128-
PscEnabled: i.pscEnabled,
129-
ServerCaMode: i.serverCAMode,
130-
}
131-
132105
r := &Request{
133106
reqMethod: http.MethodGet,
134107
reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
135108
reqCt: ct,
136109
handle: func(resp http.ResponseWriter, _ *http.Request) {
110+
var ips []*sqladmin.IpMapping
111+
for ipType, addr := range i.ipAddrs {
112+
if ipType == "PUBLIC" {
113+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
114+
continue
115+
}
116+
if ipType == "PRIVATE" {
117+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
118+
}
119+
}
120+
121+
certBytes, err := i.serverCACert()
122+
if err != nil {
123+
panic(err)
124+
}
125+
126+
db := &sqladmin.ConnectSettings{
127+
BackendType: i.backendType,
128+
DatabaseVersion: i.dbVersion,
129+
DnsName: i.DNSName,
130+
IpAddresses: ips,
131+
Region: i.region,
132+
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
133+
PscEnabled: i.pscEnabled,
134+
ServerCaMode: i.serverCAMode,
135+
}
136+
137137
b, err := db.MarshalJSON()
138138
if err != nil {
139139
http.Error(resp, err.Error(), http.StatusInternalServerError)

0 commit comments

Comments
 (0)