Skip to content

Commit 6a61292

Browse files
committed
fix: Refresh client cert when it is rejected by the server.
1 parent 2da4570 commit 6a61292

File tree

5 files changed

+146
-62
lines changed

5 files changed

+146
-62
lines changed

dialer.go

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type keyGenerator struct {
8585
// - generate an RSA key lazily when it's requested, or
8686
// - (default) immediately generate an RSA key as part of the initializer.
8787
func newKeyGenerator(
88-
k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error),
88+
k *rsa.PrivateKey, lazy bool, genFunc func() (*rsa.PrivateKey, error),
8989
) (*keyGenerator, error) {
9090
g := &keyGenerator{genFunc: genFunc}
9191
switch {
@@ -418,9 +418,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
418418
tlsConn := tls.Client(conn, ci.TLSConfig())
419419
err = tlsConn.HandshakeContext(ctx)
420420
if err != nil {
421+
// TLS handshake errors are fatal and require a refresh. Remove the instance
422+
// from the cache so that future calls to Dial() will block until the
423+
// certificate is refreshed successfully.
421424
d.logger.Debugf(ctx, "[%v] TLS handshake failed: %v", cn.String(), err)
422-
// refresh the instance info in case it caused the handshake failure
423-
c.ForceRefresh()
425+
d.removeCached(ctx, cn, c, err)
424426
_ = tlsConn.Close() // best effort close attempt
425427
return nil, errtype.NewDialError("handshake failed", cn.String(), err)
426428
}
@@ -435,7 +437,21 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
435437
iConn := newInstrumentedConn(tlsConn, func() {
436438
n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) // c.openConnsCount = c.openConnsCount - 1
437439
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
438-
}, d.dialerID, cn.String())
440+
},
441+
func(err error) {
442+
// ignore EOF
443+
if err == io.EOF {
444+
return
445+
}
446+
d.logger.Debugf(ctx, "[%v] IO Error on Read or Write: %v", cn.String(), err)
447+
if d.isTLSError(err) {
448+
// TLS errors seen on Read or Write are fatal. Remove the instance from the cache
449+
// so that future calls to Dial() will block until the certificate
450+
// is refreshed successfully.
451+
d.removeCached(ctx, cn, c, err)
452+
}
453+
},
454+
d.dialerID, cn.String())
439455

440456
// If this connection was opened using a Domain Name, then store it for later
441457
// in case it needs to be forcibly closed.
@@ -446,12 +462,18 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
446462
}
447463
return iConn, nil
448464
}
465+
func (d *Dialer) isTLSError(err error) bool {
466+
if nErr, ok := err.(net.Error); ok {
467+
return !nErr.Timeout()
468+
}
469+
return false
470+
}
449471

450472
// removeCached stops all background refreshes and deletes the connection
451473
// info cache from the map of caches.
452474
func (d *Dialer) removeCached(
453-
ctx context.Context,
454-
i instance.ConnName, c connectionInfoCache, err error,
475+
ctx context.Context,
476+
i instance.ConnName, c connectionInfoCache, err error,
455477
) {
456478
d.logger.Debugf(
457479
ctx,
@@ -461,16 +483,19 @@ func (d *Dialer) removeCached(
461483
)
462484
d.lock.Lock()
463485
defer d.lock.Unlock()
464-
c.Close()
465-
delete(d.cache, createKey(i))
486+
key := createKey(i)
487+
if cachedC, ok := d.cache[key]; ok && cachedC == c {
488+
c.Close()
489+
delete(d.cache, key)
490+
}
466491
}
467492

468493
// validClientCert checks that the ephemeral client certificate retrieved from
469494
// the cache is unexpired. The time comparisons strip the monotonic clock value
470495
// to ensure an accurate result, even after laptop sleep.
471496
func validClientCert(
472-
ctx context.Context, cn instance.ConnName,
473-
l debug.ContextLogger, expiration time.Time,
497+
ctx context.Context, cn instance.ConnName,
498+
l debug.ContextLogger, expiration time.Time,
474499
) bool {
475500
// Use UTC() to strip monotonic clock value to guard against inaccurate
476501
// comparisons, especially after laptop sleep.
@@ -535,10 +560,11 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err
535560

536561
// newInstrumentedConn initializes an instrumentedConn that on closing will
537562
// decrement the number of open connects and record the result.
538-
func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName string) *instrumentedConn {
563+
func newInstrumentedConn(conn net.Conn, closeFunc func(), errFunc func(error), dialerID, connName string) *instrumentedConn {
539564
return &instrumentedConn{
540565
Conn: conn,
541566
closeFunc: closeFunc,
567+
errFunc: errFunc,
542568
dialerID: dialerID,
543569
connName: connName,
544570
}
@@ -549,6 +575,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
549575
type instrumentedConn struct {
550576
net.Conn
551577
closeFunc func()
578+
errFunc func(error)
552579
mu sync.RWMutex
553580
closed bool
554581
dialerID string
@@ -561,6 +588,8 @@ func (i *instrumentedConn) Read(b []byte) (int, error) {
561588
bytesRead, err := i.Conn.Read(b)
562589
if err == nil {
563590
go trace.RecordBytesReceived(context.Background(), int64(bytesRead), i.connName, i.dialerID)
591+
} else {
592+
i.errFunc(err)
564593
}
565594
return bytesRead, err
566595
}
@@ -571,6 +600,8 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
571600
bytesWritten, err := i.Conn.Write(b)
572601
if err == nil {
573602
go trace.RecordBytesSent(context.Background(), int64(bytesWritten), i.connName, i.dialerID)
603+
} else {
604+
i.errFunc(err)
574605
}
575606
return bytesWritten, err
576607
}
@@ -634,7 +665,7 @@ func createKey(cn instance.ConnName) cacheKey {
634665
// connection info Cache in a threadsafe way. It will create a new cache,
635666
// modify the existing one, or leave it unchanged as needed.
636667
func (d *Dialer) connectionInfoCache(
637-
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
668+
ctx context.Context, cn instance.ConnName, useIAMAuthN *bool,
638669
) (*monitoredCache, error) {
639670
k := createKey(cn)
640671

dialer_test.go

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
11751175
}
11761176
}
11771177

1178-
func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
1178+
func TestDialerRefreshesAfterRotateClientCA(t *testing.T) {
11791179
inst := mock.NewFakeCSQLInstanceWithSan(
11801180
"my-project", "my-region", "my-instance", []string{"db.example.com"},
11811181
mock.WithDNS("db.example.com"),
@@ -1210,14 +1210,13 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
12101210
cancel1()
12111211

12121212
mock.RotateClientCA(inst)
1213-
time.Sleep(2 * time.Second)
12141213

1215-
// Recreate the instance, which generates new server certificates
12161214
// Start the server with new certificates
12171215
cancel2 := mock.StartServerProxy(t, inst)
12181216
defer cancel2()
12191217

12201218
// Dial a second time. We expect no error on dial, but TLS error on read.
1219+
t.Log("Second attempt should fail...")
12211220
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
12221221
if err != nil {
12231222
t.Fatal("Should be no certificate error after, got ", err)
@@ -1230,8 +1229,66 @@ func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
12301229
} else {
12311230
t.Fatal("Want read error, got no error")
12321231
}
1232+
t.Log("Second attempt done")
1233+
1234+
// Dial again. This should complete after the refresh.
1235+
t.Log("Third attempt...")
1236+
testSuccessfulDial(
1237+
context.Background(), t, d,
1238+
"my-project:my-region:my-instance",
1239+
)
1240+
t.Log("Third attempt OK.")
1241+
}
1242+
1243+
func TestDialerRefreshesAfterRotateServerCA(t *testing.T) {
1244+
inst := mock.NewFakeCSQLInstanceWithSan(
1245+
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1246+
mock.WithDNS("db.example.com"),
1247+
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1248+
)
1249+
1250+
d := setupDialer(t, setupConfig{
1251+
skipServer: true,
1252+
testInstance: inst,
1253+
reqs: []*mock.Request{
1254+
mock.InstanceGetSuccess(inst, 2),
1255+
mock.CreateEphemeralSuccess(inst, 2),
1256+
},
1257+
dialerOptions: []Option{
1258+
WithTokenSource(mock.EmptyTokenSource{}),
1259+
WithDebugLogger(&dialerTestLogger{t: t}),
1260+
WithLazyRefresh(),
1261+
// Note: this succeeds with lazy refresh, but fails with non-lazy refresh,
1262+
// because dialer.ForceRefresh does not block connections while the
1263+
// refresh is in progress.
1264+
},
1265+
})
1266+
cancel1 := mock.StartServerProxy(t, inst)
1267+
t.Log("First attempt...")
1268+
testSuccessfulDial(
1269+
context.Background(), t, d,
1270+
"my-project:my-region:my-instance",
1271+
)
1272+
t.Log("First attempt OK. Resetting client cert.")
1273+
1274+
// Close the server
1275+
cancel1()
1276+
1277+
mock.RotateCA(inst)
1278+
1279+
// Start the server with new certificates
1280+
cancel2 := mock.StartServerProxy(t, inst)
1281+
defer cancel2()
1282+
1283+
// Dial a second time. We expect an error on dial.
1284+
t.Log("Second attempt should fail...")
1285+
_, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1286+
if err != nil {
1287+
t.Log("Got error on dial as expected.", err)
1288+
} else {
1289+
t.Fatal("Want dial error, got no error")
1290+
}
12331291

1234-
time.Sleep(2 * time.Second)
12351292
// Dial again. This should occur after the refresh has completed.
12361293
t.Log("Third attempt...")
12371294
testSuccessfulDial(

internal/mock/certs.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,6 @@ func (ct *TLSCertificates) GenerateServerCertWithCn(cn string) *x509.Certificate
239239
time.Now().Add(1*time.Hour), nil)
240240
}
241241

242-
func (ct *TLSCertificates) GenerateServerCertWithSan(sans []string) *x509.Certificate {
243-
return mustBuildSignedCertificate(
244-
false,
245-
name(""),
246-
ct.serverKeyPair,
247-
intermediateCaSubject,
248-
ct.serverIntermediateCaKeyPair,
249-
time.Now().Add(1*time.Hour), sans)
250-
}
251-
252242
// ServerChain creates a []tls.Certificate for use with a TLS server socket.
253243
// serverCAMode controls whether this returns a legacy or CAS server
254244
// certificate.
@@ -269,17 +259,21 @@ func (ct *TLSCertificates) ServerChain(serverCAMode string) []tls.Certificate {
269259
}}
270260

271261
}
262+
263+
// ClientCAPool returns a CertPool with the client CA.
272264
func (ct *TLSCertificates) ClientCAPool() *x509.CertPool {
273265
clientCa := x509.NewCertPool()
274266
clientCa.AddCert(ct.clientSigningCACertificate)
275267
return clientCa
276268
}
277269

270+
// RotateClientCA rotates only client CA certificates and keys.
278271
func (ct *TLSCertificates) RotateClientCA() {
279272
ct.clientSigningCaKeyPair = mustGenerateKey()
280273
ct.clientSigningCACertificate = mustBuildRootCertificate(signingCaSubject, ct.clientSigningCaKeyPair)
281274
}
282275

276+
// RotateCA rotates all certificates and keys.
283277
func (ct *TLSCertificates) RotateCA() {
284278
oneYear := time.Now().AddDate(1, 0, 0)
285279
ct.serverCaKeyPair = mustGenerateKey()

internal/mock/cloudsql.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,17 @@ func (f FakeCSQLInstance) String() string {
7070
return fmt.Sprintf("%v:%v:%v", f.project, f.region, f.name)
7171
}
7272

73-
func (f FakeCSQLInstance) ServerCACert() ([]byte, error) {
73+
func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
7474
if f.signer != nil {
7575
return f.signer(f.Cert, f.Key)
76-
} else {
77-
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
78-
// legacy server mode, return only the server cert
79-
return toPEMFormat(f.certs.serverCert)
80-
}
81-
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
8276
}
77+
78+
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
79+
// legacy server mode, return only the server cert
80+
return toPEMFormat(f.certs.serverCert)
81+
}
82+
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
83+
8384
}
8485

8586
// ClientCert creates an ephemeral client certificate signed with the Cloud SQL
@@ -272,9 +273,8 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
272273
if opErr, ok := aErr.(net.Error); ok {
273274
if opErr.Timeout() {
274275
continue
275-
} else {
276-
return
277276
}
277+
return
278278
}
279279
if aErr == io.EOF {
280280
return
@@ -297,10 +297,12 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
297297
}
298298
}
299299

300+
// RotateCA rotates all CA certificates in the instance.
300301
func RotateCA(inst FakeCSQLInstance) {
301302
inst.certs.RotateCA()
302303
}
303304

305+
// RotateClientCA rotates only the client CA certificates in the instance.
304306
func RotateClientCA(inst FakeCSQLInstance) {
305307
inst.certs.RotateClientCA()
306308
}

internal/mock/sqladmin.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,38 @@ func (r *Request) matches(hR *http.Request) bool {
102102
//
103103
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/instances/get
104104
func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request {
105-
var ips []*sqladmin.IpMapping
106-
for ipType, addr := range i.ipAddrs {
107-
if ipType == "PUBLIC" {
108-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
109-
continue
110-
}
111-
if ipType == "PRIVATE" {
112-
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
113-
}
114-
}
115-
116-
certBytes, err := i.ServerCACert()
117-
if err != nil {
118-
panic(err)
119-
}
120-
121-
db := &sqladmin.ConnectSettings{
122-
BackendType: i.backendType,
123-
DatabaseVersion: i.dbVersion,
124-
DnsName: i.DNSName,
125-
IpAddresses: ips,
126-
Region: i.region,
127-
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
128-
PscEnabled: i.pscEnabled,
129-
ServerCaMode: i.serverCAMode,
130-
}
131-
132105
r := &Request{
133106
reqMethod: http.MethodGet,
134107
reqPath: fmt.Sprintf("/sql/v1beta4/projects/%s/instances/%s/connectSettings", i.project, i.name),
135108
reqCt: ct,
136109
handle: func(resp http.ResponseWriter, _ *http.Request) {
110+
var ips []*sqladmin.IpMapping
111+
for ipType, addr := range i.ipAddrs {
112+
if ipType == "PUBLIC" {
113+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIMARY"})
114+
continue
115+
}
116+
if ipType == "PRIVATE" {
117+
ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"})
118+
}
119+
}
120+
121+
certBytes, err := i.serverCACert()
122+
if err != nil {
123+
panic(err)
124+
}
125+
126+
db := &sqladmin.ConnectSettings{
127+
BackendType: i.backendType,
128+
DatabaseVersion: i.dbVersion,
129+
DnsName: i.DNSName,
130+
IpAddresses: ips,
131+
Region: i.region,
132+
ServerCaCert: &sqladmin.SslCert{Cert: string(certBytes)},
133+
PscEnabled: i.pscEnabled,
134+
ServerCaMode: i.serverCAMode,
135+
}
136+
137137
b, err := db.MarshalJSON()
138138
if err != nil {
139139
http.Error(resp, err.Error(), http.StatusInternalServerError)

0 commit comments

Comments
 (0)