Skip to content

Commit c22e2d4

Browse files
authored
fix: Refresh client cert when CA rotation (#934)
When the client CA is rotated, TLS read errors can occur after Dialer.Dial() has returned, and the client certificate must be refreshed. This adds logic to the dialer to refresh the connection cache if a TLS error occurs after the net.Conn is returned to the database driver. Fixes #932
1 parent ae376f1 commit c22e2d4

File tree

3 files changed

+206
-39
lines changed

3 files changed

+206
-39
lines changed

dialer.go

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
419419
tlsConn := tls.Client(conn, ci.TLSConfig())
420420
err = tlsConn.HandshakeContext(ctx)
421421
if err != nil {
422+
// TLS handshake errors are fatal and require a refresh. Remove the instance
423+
// from the cache so that future calls to Dial() will block until the
424+
// certificate is refreshed successfully.
422425
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
423-
// refresh the instance info in case it caused the handshake failure
424-
c.ForceRefresh()
426+
d.removeCached(ctx, cn, c, err)
425427
_ = tlsConn.Close() // best effort close attempt
426428
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
427429
}
@@ -433,10 +435,26 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
433435
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
434436
}()
435437

436-
iConn := newInstrumentedConn(tlsConn, func() {
438+
closeFunc := func() {
437439
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
438440
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
439-
}, d.dialerID, cn.String())
441+
}
442+
errFunc := func(err error) {
443+
// io.EOF occurs when the server closes the connection. This is safe to
444+
// ignore.
445+
if err == io.EOF {
446+
return
447+
}
448+
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
449+
if d.isTLSError(err) {
450+
// TLS errors on Read or Write are fatal. Remove the instance from the cache
451+
// so that future calls to Dial() will block until the certificate
452+
// is refreshed successfully.
453+
d.removeCached(ctx, cn, c, err)
454+
_ = tlsConn.Close() // best effort close attempt
455+
}
456+
}
457+
iConn := newInstrumentedConn(tlsConn, closeFunc, errFunc, d.dialerID, cn.String())
440458

441459
// If this connection was opened using a Domain Name, then store it for later
442460
// in case it needs to be forcibly closed.
@@ -447,23 +465,39 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
447465
}
448466
return iConn, nil
449467
}
468+
func (d *Dialer) isTLSError(err error) bool {
469+
if nErr, ok := err.(net.Error); ok {
470+
return !nErr.Timeout() && // it's a permanent net error
471+
strings.Contains(nErr.Error(), "tls") // it's a TLS-related error
472+
}
473+
return false
474+
}
450475

451-
// removeCached stops all background refreshes and deletes the connection
452-
// info cache from the map of caches.
476+
// removeCached stops all background refreshes, closes open sockets, and deletes
477+
// the cache entry.
453478
func (d *Dialer) removeCached(
454479
ctx context.Context,
455-
i instance.ConnName, c connectionInfoCache, err error,
480+
i instance.ConnName, c *monitoredCache, err error,
456481
) {
457482
d.logger.Debugf(
458483
ctx,
459484
"[%v] Removing connection info from cache: %v",
460485
i.String(),
461486
err,
462487
)
488+
489+
// If this instance of monitoredCache is still in the cache, remove it.
490+
// If this instance was already removed from the cache or
491+
// if *a separate goroutine* replaced it with a new instance, do nothing.
492+
key := createKey(i)
463493
d.lock.Lock()
464-
defer d.lock.Unlock()
494+
if cachedC, ok := d.cache[key]; ok && cachedC == c {
495+
delete(d.cache, key)
496+
}
497+
d.lock.Unlock()
498+
499+
// Close the monitoredCache, this call is idempotent.
465500
c.Close()
466-
delete(d.cache, createKey(i))
467501
}
468502

469503
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -505,7 +539,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
505539
}
506540
ci, err := c.ConnectionInfo(ctx)
507541
if err != nil {
508-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
542+
d.removeCached(ctx, cn, c, err)
509543
return "", err
510544
}
511545
return ci.DBVersion, nil
@@ -529,17 +563,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
529563
}
530564
_, err = c.ConnectionInfo(ctx)
531565
if err != nil {
532-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
566+
d.removeCached(ctx, cn, c, err)
533567
}
534568
return err
535569
}
536570

537571
// newInstrumentedConn initializes an instrumentedConn that on closing will
538572
// decrement the number of open connections and record the result.
539-
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
573+
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
540574
return &instrumentedConn{
541575
Conn: conn,
542576
closeFunc: closeFunc,
577+
errFunc: errFunc,
543578
dialerID: dialerID,
544579
connName: connName,
545580
}
@@ -550,6 +585,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
550585
type instrumentedConn struct {
551586
net.Conn
552587
closeFunc func()
588+
errFunc func(error)
553589
mu sync.RWMutex
554590
closed bool
555591
dialerID string
@@ -562,6 +598,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
562598
bytesRead, err := i.Conn.Read(b)
563599
if err == nil {
564600
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
601+
} else {
602+
i.errFunc(err)
565603
}
566604
return bytesRead, err
567605
}
@@ -572,6 +610,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
572610
bytesWritten, err := i.Conn.Write(b)
573611
if err == nil {
574612
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
613+
} else {
614+
i.errFunc(err)
575615
}
576616
return bytesWritten, err
577617
}

dialer_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,3 +1174,127 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
11741174
t.Fatal("want error containing `tls: failed to verify certificate`. Got: ", err)
11751175
}
11761176
}
1177+
1178+
func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
1179+
tcs := []struct {
1180+
desc string
1181+
rotateClientCA bool
1182+
wantErrorOnDial bool
1183+
wantErrorOnRead bool
1184+
useLazyRefresh bool
1185+
}{
1186+
{
1187+
desc: "Rotating Client CA causes error on read, then refresh",
1188+
rotateClientCA: true,
1189+
wantErrorOnRead: true,
1190+
},
1191+
{
1192+
desc: "Rotating all CAs causes error on dial, then refresh",
1193+
wantErrorOnDial: true,
1194+
wantErrorOnRead: false,
1195+
},
1196+
{
1197+
desc: "Rotating Client CA with lazy refresh causes error on read",
1198+
rotateClientCA: true,
1199+
wantErrorOnRead: true,
1200+
useLazyRefresh: true,
1201+
},
1202+
{
1203+
desc: "Rotating all CAs with lazy refresh causes error on dial",
1204+
wantErrorOnDial: true,
1205+
useLazyRefresh: true,
1206+
},
1207+
}
1208+
for _, tc := range tcs {
1209+
t.Run(tc.desc, func(t *testing.T) {
1210+
inst := mock.NewFakeCSQLInstanceWithSan(
1211+
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1212+
mock.WithDNS("db.example.com"),
1213+
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1214+
)
1215+
1216+
opts := []Option{
1217+
WithTokenSource(mock.EmptyTokenSource{}),
1218+
WithDebugLogger(&dialerTestLogger{t: t}),
1219+
}
1220+
if tc.useLazyRefresh {
1221+
opts = append(opts, WithLazyRefresh())
1222+
}
1223+
1224+
d := setupDialer(t, setupConfig{
1225+
skipServer: true,
1226+
testInstance: inst,
1227+
reqs: []*mock.Request{
1228+
mock.InstanceGetSuccess(inst, 2),
1229+
mock.CreateEphemeralSuccess(inst, 2),
1230+
},
1231+
dialerOptions: opts,
1232+
})
1233+
cancel1 := mock.StartServerProxy(t, inst)
1234+
t.Log("First attempt...")
1235+
testSuccessfulDial(
1236+
context.Background(), t, d,
1237+
"my-project:my-region:my-instance",
1238+
)
1239+
t.Log("First attempt OK. Resetting client cert.")
1240+
1241+
// Close the server
1242+
cancel1()
1243+
1244+
if tc.rotateClientCA {
1245+
mock.RotateClientCA(inst)
1246+
} else {
1247+
mock.RotateCA(inst)
1248+
}
1249+
1250+
// Start the server with new certificates
1251+
cancel2 := mock.StartServerProxy(t, inst)
1252+
defer cancel2()
1253+
1254+
// Dial a second time.
1255+
t.Log("Second attempt should fail...")
1256+
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1257+
if err != nil {
1258+
if tc.wantErrorOnDial {
1259+
t.Logf("got error on dial as expected: %v", err)
1260+
} else {
1261+
t.Fatalf("want no dial error, got: %v", err)
1262+
}
1263+
} else if tc.wantErrorOnDial {
1264+
t.Fatal("want dial error, got no error")
1265+
}
1266+
1267+
// If no error expected on dial, then attempt to read.
1268+
if !tc.wantErrorOnDial {
1269+
_, err = io.ReadAll(conn)
1270+
if err != nil {
1271+
if tc.wantErrorOnRead {
1272+
t.Logf("got error on read as expected: %v", err)
1273+
} else {
1274+
t.Fatalf("want no read error, got: %v", err)
1275+
}
1276+
} else if tc.wantErrorOnRead {
1277+
t.Fatal("want read error, got no error")
1278+
}
1279+
}
1280+
t.Log("Second attempt done")
1281+
1282+
// Dial again. This should complete after the refresh.
1283+
t.Log("Third attempt...")
1284+
testSuccessfulDial(
1285+
context.Background(), t, d,
1286+
"my-project:my-region:my-instance",
1287+
)
1288+
t.Log("Third attempt OK.")
1289+
})
1290+
}
1291+
1292+
}
1293+
1294+
type dialerTestLogger struct {
1295+
t *testing.T
1296+
}
1297+
1298+
func (l *dialerTestLogger) Debugf(f string, args ...interface{}) {
1299+
l.t.Logf(f, args...)
1300+
}

internal/mock/sqladmin.go

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,41 @@ func (r *Request) matches(hR *http.Request) bool {
102102
//
103103
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
104104
func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
105-
var ips []*sqladmin.IpMapping
106-
for ipType, addr := range i.ipAddrs {
107-
if ipType == "PUBLIC" {
108-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
109-
continue
110-
}
111-
if ipType == "PRIVATE" {
112-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
113-
}
114-
}
115-
116-
certBytes, err := i.serverCACert()
117-
if err != nil {
118-
panic(err)
119-
}
120-
121-
db := &sqladmin.ConnectSettings{
122-
BackendType: i.backendType,
123-
DatabaseVersion: i.dbVersion,
124-
DnsName: i.DNSName,
125-
IpAddresses: ips,
126-
Region: i.region,
127-
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
128-
PscEnabled: i.pscEnabled,
129-
ServerCaMode: i.serverCAMode,
130-
}
131-
132105
r := &Request{
133106
reqMethod: http.MethodGet,
134107
reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
135108
reqCt: ct,
136109
handle: func(resp http.ResponseWriter, _ *http.Request) {
110+
// Calculate the response when the request occurs so that the response contains
111+
// up-to-date data stored in the FakeCSQLInstance.
112+
// This is especially important for the i.serverCACert().
113+
var ips []*sqladmin.IpMapping
114+
for ipType, addr := range i.ipAddrs {
115+
if ipType == "PUBLIC" {
116+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
117+
continue
118+
}
119+
if ipType == "PRIVATE" {
120+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
121+
}
122+
}
123+
124+
certBytes, err := i.serverCACert()
125+
if err != nil {
126+
panic(err)
127+
}
128+
129+
db := &sqladmin.ConnectSettings{
130+
BackendType: i.backendType,
131+
DatabaseVersion: i.dbVersion,
132+
DnsName: i.DNSName,
133+
IpAddresses: ips,
134+
Region: i.region,
135+
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
136+
PscEnabled: i.pscEnabled,
137+
ServerCaMode: i.serverCAMode,
138+
}
139+
137140
b, err := db.MarshalJSON()
138141
if err != nil {
139142
http.Error(resp, err.Error(), http.StatusInternalServerError)

0 commit comments

Comments
 (0)