Skip to content

Commit eb3a6db

Browse files
committed
fix: Refresh client cert when it is rejected by the server.
1 parent e7649f9 commit eb3a6db

File tree

3 files changed

+179
-147
lines changed

3 files changed

+179
-147
lines changed

dialer.go

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
419419
tlsConn := tls.Client(conn, ci.TLSConfig())
420420
err = tlsConn.HandshakeContext(ctx)
421421
if err != nil {
422+
// TLS handshake errors are fatal and require a refresh. Remove the instance
423+
// from the cache so that future calls to Dial() will block until the
424+
// certificate is refreshed successfully.
422425
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
423-
// refresh the instance info in case it caused the handshake failure
424-
c.ForceRefresh()
426+
d.removeCached(ctx, cn, c, err)
425427
_ = tlsConn.Close() // best effort close attempt
426428
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
427429
}
@@ -433,10 +435,26 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
433435
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
434436
}()
435437

436-
iConn := newInstrumentedConn(tlsConn, func() {
438+
closeFunc := func() {
437439
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
438440
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
439-
}, d.dialerID, cn.String())
441+
}
442+
errFunc := func(err error) {
443+
// io.EOF occurs when the server closes the connection. This is safe to
444+
// ignore.
445+
if err == io.EOF {
446+
return
447+
}
448+
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
449+
if d.isTLSError(err) {
450+
// TLS errors on Read or Write are fatal. Remove the instance from the cache
451+
// so that future calls to Dial() will block until the certificate
452+
// is refreshed successfully.
453+
d.removeCached(ctx, cn, c, err)
454+
_ = tlsConn.Close() // best effort close attempt
455+
}
456+
}
457+
iConn := newInstrumentedConn(tlsConn, closeFunc, errFunc, d.dialerID, cn.String())
440458

441459
// If this connection was opened using a Domain Name, then store it for later
442460
// in case it needs to be forcibly closed.
@@ -447,23 +465,39 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
447465
}
448466
return iConn, nil
449467
}
468+
func (d *Dialer) isTLSError(err error) bool {
469+
if nErr, ok := err.(net.Error); ok {
470+
return !nErr.Timeout() && // it's a permanent net error
471+
strings.Contains(nErr.Error(), "tls") // it's a TLS-related error
472+
}
473+
return false
474+
}
450475

451-
// removeCached stops all background refreshes and deletes the connection
452-
// info cache from the map of caches.
476+
// removeCached stops all background refreshes, closes open sockets, and deletes
477+
// the cache entry.
453478
func (d *Dialer) removeCached(
454479
ctx context.Context,
455-
i instance.ConnName, c connectionInfoCache, err error,
480+
i instance.ConnName, c *monitoredCache, err error,
456481
) {
457482
d.logger.Debugf(
458483
ctx,
459484
"[%v] Removing connection info from cache: %v",
460485
i.String(),
461486
err,
462487
)
488+
489+
// If this instance of monitoredCache is still in the cache, remove it.
490+
// If this instance was already removed from the cache or replaced with
491+
// a new instance, do nothing.
492+
key := createKey(i)
463493
d.lock.Lock()
464-
defer d.lock.Unlock()
494+
if cachedC, ok := d.cache[key]; ok && cachedC == c {
495+
delete(d.cache, key)
496+
}
497+
d.lock.Unlock()
498+
499+
// Close the monitoredCache; this call is idempotent.
465500
c.Close()
466-
delete(d.cache, createKey(i))
467501
}
468502

469503
// validClientCert checks that the ephemeral client certificate retrieved from
@@ -505,7 +539,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
505539
}
506540
ci, err := c.ConnectionInfo(ctx)
507541
if err != nil {
508-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
542+
d.removeCached(ctx, cn, c, err)
509543
return "", err
510544
}
511545
return ci.DBVersion, nil
@@ -529,17 +563,18 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
529563
}
530564
_, err = c.ConnectionInfo(ctx)
531565
if err != nil {
532-
d.removeCached(ctx, cn, c.connectionInfoCache, err)
566+
d.removeCached(ctx, cn, c, err)
533567
}
534568
return err
535569
}
536570

537571
// newInstrumentedConn initializes an instrumentedConn that on closing will
538572
// decrement the number of open connections and record the result.
539-
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
573+
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
540574
return &instrumentedConn{
541575
Conn: conn,
542576
closeFunc: closeFunc,
577+
errFunc: errFunc,
543578
dialerID: dialerID,
544579
connName: connName,
545580
}
@@ -550,6 +585,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
550585
type instrumentedConn struct {
551586
net.Conn
552587
closeFunc func()
588+
errFunc func(error)
553589
mu sync.RWMutex
554590
closed bool
555591
dialerID string
@@ -562,6 +598,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
562598
bytesRead, err := i.Conn.Read(b)
563599
if err == nil {
564600
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
601+
} else {
602+
i.errFunc(err)
565603
}
566604
return bytesRead, err
567605
}
@@ -572,6 +610,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
572610
bytesWritten, err := i.Conn.Write(b)
573611
if err == nil {
574612
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
613+
} else {
614+
i.errFunc(err)
575615
}
576616
return bytesWritten, err
577617
}

dialer_test.go

Lines changed: 100 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,127 +1175,119 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
11751175
}
11761176
}
11771177

1178-
func TestDialerRefreshesAfterRotateClientCA(t *testing.T) {
1179-
inst := mock.NewFakeCSQLInstanceWithSan(
1180-
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1181-
mock.WithDNS("db.example.com"),
1182-
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1183-
)
1184-
1185-
d := setupDialer(t, setupConfig{
1186-
skipServer: true,
1187-
testInstance: inst,
1188-
reqs: []*mock.Request{
1189-
mock.InstanceGetSuccess(inst, 2),
1190-
mock.CreateEphemeralSuccess(inst, 2),
1178+
func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
1179+
tcs := []struct {
1180+
desc string
1181+
rotateClientCA bool
1182+
wantErrorOnDial bool
1183+
wantErrorOnRead bool
1184+
useLazyRefresh bool
1185+
}{{
1186+
desc: "Rotating Client CA causes error on read, then refresh",
1187+
rotateClientCA: true,
1188+
wantErrorOnRead: true,
1189+
},
1190+
{
1191+
desc: "Rotating all CAs causes error on dial, then refresh",
1192+
wantErrorOnDial: true,
1193+
wantErrorOnRead: false,
11911194
},
1192-
dialerOptions: []Option{
1193-
WithTokenSource(mock.EmptyTokenSource{}),
1194-
WithDebugLogger(&dialerTestLogger{t: t}),
1195-
//WithLazyRefresh(),
1196-
// Note: this succeeds with lazy refresh, but fails with lazy.
1197-
// because dialer.ForceRefresh does not block connections while the
1198-
// refresh is in progress.
1195+
{
1196+
desc: "Rotating Client CA with lazy refresh causes error on read",
1197+
rotateClientCA: true,
1198+
wantErrorOnRead: true,
1199+
useLazyRefresh: true,
1200+
},
1201+
{
1202+
desc: "Rotating all CAs with lazy refresh causes error on dial",
1203+
wantErrorOnDial: true,
1204+
useLazyRefresh: true,
11991205
},
1200-
})
1201-
cancel1 := mock.StartServerProxy(t, inst)
1202-
t.Log("First attempt...")
1203-
testSuccessfulDial(
1204-
context.Background(), t, d,
1205-
"my-project:my-region:my-instance",
1206-
)
1207-
t.Log("First attempt OK. Resetting client cert.")
1208-
1209-
// Close the server
1210-
cancel1()
1211-
1212-
mock.RotateClientCA(inst)
1213-
1214-
// Start the server with new certificates
1215-
cancel2 := mock.StartServerProxy(t, inst)
1216-
defer cancel2()
1217-
1218-
// Dial a second time. We expect no error on dial, but TLS error on read.
1219-
t.Log("Second attempt should fail...")
1220-
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1221-
if err != nil {
1222-
t.Fatal("Should be no certificate error after, got ", err)
12231206
}
1207+
for _, tc := range tcs {
1208+
t.Run(tc.desc, func(t *testing.T) {
1209+
inst := mock.NewFakeCSQLInstanceWithSan(
1210+
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1211+
mock.WithDNS("db.example.com"),
1212+
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1213+
)
12241214

1225-
// Expect an error on read. This should trigger the dialer to refresh.
1226-
_, err = io.ReadAll(conn)
1227-
if err != nil {
1228-
t.Log("Got error on read as expected.", err)
1229-
} else {
1230-
t.Fatal("Want read error, got no error")
1231-
}
1232-
t.Log("Second attempt done")
1215+
opts := []Option{
1216+
WithTokenSource(mock.EmptyTokenSource{}),
1217+
WithDebugLogger(&dialerTestLogger{t: t}),
1218+
}
1219+
if tc.useLazyRefresh {
1220+
opts = append(opts, WithLazyRefresh())
1221+
}
12331222

1234-
// Dial again. This should complete after the refresh.
1235-
t.Log("Third attempt...")
1236-
testSuccessfulDial(
1237-
context.Background(), t, d,
1238-
"my-project:my-region:my-instance",
1239-
)
1240-
t.Log("Third attempt OK.")
1241-
}
1223+
d := setupDialer(t, setupConfig{
1224+
skipServer: true,
1225+
testInstance: inst,
1226+
reqs: []*mock.Request{
1227+
mock.InstanceGetSuccess(inst, 2),
1228+
mock.CreateEphemeralSuccess(inst, 2),
1229+
},
1230+
dialerOptions: opts,
1231+
})
1232+
cancel1 := mock.StartServerProxy(t, inst)
1233+
t.Log("First attempt...")
1234+
testSuccessfulDial(
1235+
context.Background(), t, d,
1236+
"my-project:my-region:my-instance",
1237+
)
1238+
t.Log("First attempt OK. Resetting client cert.")
12421239

1243-
func TestDialerRefreshesAfterRotateServerCA(t *testing.T) {
1244-
inst := mock.NewFakeCSQLInstanceWithSan(
1245-
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1246-
mock.WithDNS("db.example.com"),
1247-
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1248-
)
1240+
// Close the server
1241+
cancel1()
12491242

1250-
d := setupDialer(t, setupConfig{
1251-
skipServer: true,
1252-
testInstance: inst,
1253-
reqs: []*mock.Request{
1254-
mock.InstanceGetSuccess(inst, 2),
1255-
mock.CreateEphemeralSuccess(inst, 2),
1256-
},
1257-
dialerOptions: []Option{
1258-
WithTokenSource(mock.EmptyTokenSource{}),
1259-
WithDebugLogger(&dialerTestLogger{t: t}),
1260-
WithLazyRefresh(),
1261-
// Note: this succeeds with lazy refresh, but fails with lazy.
1262-
// because dialer.ForceRefresh does not block connections while the
1263-
// refresh is in progress.
1264-
},
1265-
})
1266-
cancel1 := mock.StartServerProxy(t, inst)
1267-
t.Log("First attempt...")
1268-
testSuccessfulDial(
1269-
context.Background(), t, d,
1270-
"my-project:my-region:my-instance",
1271-
)
1272-
t.Log("First attempt OK. Resetting client cert.")
1243+
if tc.rotateClientCA {
1244+
mock.RotateClientCA(inst)
1245+
} else {
1246+
mock.RotateCA(inst)
1247+
}
12731248

1274-
// Close the server
1275-
cancel1()
1249+
// Start the server with new certificates
1250+
cancel2 := mock.StartServerProxy(t, inst)
1251+
defer cancel2()
12761252

1277-
mock.RotateCA(inst)
1253+
// Dial a second time.
1254+
t.Log("Second attempt should fail...")
1255+
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1256+
if err != nil {
1257+
if tc.wantErrorOnDial {
1258+
t.Logf("got error on dial as expected: %v", err)
1259+
} else {
1260+
t.Fatalf("want no dial error, got: %v", err)
1261+
}
1262+
} else if tc.wantErrorOnDial {
1263+
t.Fatal("want dial error, got no error")
1264+
}
12781265

1279-
// Start the server with new certificates
1280-
cancel2 := mock.StartServerProxy(t, inst)
1281-
defer cancel2()
1266+
// If no error expected on dial, then attempt to read.
1267+
if !tc.wantErrorOnDial {
1268+
_, err = io.ReadAll(conn)
1269+
if err != nil {
1270+
if tc.wantErrorOnRead {
1271+
t.Logf("got error on read as expected: %v", err)
1272+
} else {
1273+
t.Fatalf("want no read error, got: %v", err)
1274+
}
1275+
} else if tc.wantErrorOnRead {
1276+
t.Fatal("want read error, got no error")
1277+
}
1278+
}
1279+
t.Log("Second attempt done")
12821280

1283-
// Dial a second time. We expect no error on dial, but TLS error on read.
1284-
t.Log("Second attempt should fail...")
1285-
_, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1286-
if err != nil {
1287-
t.Log("Got error on dial as expected.", err)
1288-
} else {
1289-
t.Fatal("Want dial error, got no error")
1281+
// Dial again. This should complete after the refresh.
1282+
t.Log("Third attempt...")
1283+
testSuccessfulDial(
1284+
context.Background(), t, d,
1285+
"my-project:my-region:my-instance",
1286+
)
1287+
t.Log("Third attempt OK.")
1288+
})
12901289
}
12911290

1292-
// Dial again. This should occur after the refresh has completed.
1293-
t.Log("Third attempt...")
1294-
testSuccessfulDial(
1295-
context.Background(), t, d,
1296-
"my-project:my-region:my-instance",
1297-
)
1298-
t.Log("Third attempt OK.")
12991291
}
13001292

13011293
type dialerTestLogger struct {

0 commit comments

Comments
 (0)