Skip to content

Commit aa5996a

Browse files
committed
wip: Failing test - error on read after dial.
1 parent 3a7c320 commit aa5996a

File tree

3 files changed

+173
-89
lines changed

3 files changed

+173
-89
lines changed

dialer_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,3 +1174,77 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
11741174
t.Fatal("want error containing `tls: failed to verify certificate`. Got: ", err)
11751175
}
11761176
}
1177+
1178+
func TestDialerRefreshesAfterClientCertificateError(t *testing.T) {
1179+
inst := mock.NewFakeCSQLInstanceWithSan(
1180+
"my-project", "my-region", "my-instance", []string{"db.example.com"},
1181+
mock.WithDNS("db.example.com"),
1182+
mock.WithServerCAMode("GOOGLE_MANAGED_CAS_CA"),
1183+
)
1184+
1185+
d := setupDialer(t, setupConfig{
1186+
skipServer: true,
1187+
testInstance: inst,
1188+
reqs: []*mock.Request{
1189+
mock.InstanceGetSuccess(inst, 2),
1190+
mock.CreateEphemeralSuccess(inst, 2),
1191+
},
1192+
dialerOptions: []Option{
1193+
WithTokenSource(mock.EmptyTokenSource{}),
1194+
WithDebugLogger(&dialerTestLogger{t: t}),
1195+
//WithLazyRefresh(),
1196+
// Note: this succeeds with lazy refresh, but fails with lazy.
1197+
// because dialer.ForceRefresh does not block connections while the
1198+
// refresh is in progress.
1199+
},
1200+
})
1201+
cancel1 := mock.StartServerProxy(t, inst)
1202+
t.Log("First attempt...")
1203+
testSuccessfulDial(
1204+
context.Background(), t, d,
1205+
"my-project:my-region:my-instance",
1206+
)
1207+
t.Log("First attempt OK. Resetting client cert.")
1208+
1209+
// Close the server
1210+
cancel1()
1211+
1212+
mock.RotateClientCA(inst)
1213+
time.Sleep(2 * time.Second)
1214+
1215+
// Recreate the instance, which generates new server certificates
1216+
// Start the server with new certificates
1217+
cancel2 := mock.StartServerProxy(t, inst)
1218+
defer cancel2()
1219+
1220+
// Dial a second time. We expect no error on dial, but TLS error on read.
1221+
conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance")
1222+
if err != nil {
1223+
t.Fatal("Should be no certificate error after, got ", err)
1224+
}
1225+
1226+
// Expect an error on read. This should trigger the dialer to refresh.
1227+
_, err = io.ReadAll(conn)
1228+
if err != nil {
1229+
t.Log("Got error on read as expected.", err)
1230+
} else {
1231+
t.Fatal("Want read error, got no error")
1232+
}
1233+
1234+
time.Sleep(2 * time.Second)
1235+
// Dial again. This should occur after the refresh has completed.
1236+
t.Log("Third attempt...")
1237+
testSuccessfulDial(
1238+
context.Background(), t, d,
1239+
"my-project:my-region:my-instance",
1240+
)
1241+
t.Log("Third attempt OK.")
1242+
}
1243+
1244+
type dialerTestLogger struct {
1245+
t *testing.T
1246+
}
1247+
1248+
func (l *dialerTestLogger) Debugf(f string, args ...interface{}) {
1249+
l.t.Logf(f, args...)
1250+
}

internal/mock/certs.go

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ var instanceWithCnSubject = name("myProject:myInstance")
4848
// From the cloud-sql-jdbc-socket-factory project:
4949
// core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java
5050
type TLSCertificates struct {
51-
ServerCaKeyPair *rsa.PrivateKey
52-
SigningCaKeyPair *rsa.PrivateKey
53-
ServerKeyPair *rsa.PrivateKey
54-
ServerIntermediateCaKeyPair *rsa.PrivateKey
55-
ServerSigningCaKeyPair *rsa.PrivateKey
56-
ClientKeyPair *rsa.PrivateKey
57-
DomainServerKeyPair *rsa.PrivateKey
58-
59-
ServerCaCert *x509.Certificate
60-
SigningCaCert *x509.Certificate
61-
ServerCert *x509.Certificate
62-
ServerIntermediateCaCert *x509.Certificate
63-
CasServerCertificate *x509.Certificate
64-
CasServerCertificateChain []*x509.Certificate
65-
DomainServerCertificate *x509.Certificate
66-
clientCertExpires time.Time
51+
clientCertExpires time.Time
52+
projectName string
53+
instanceName string
54+
sans []string
55+
56+
serverCaKeyPair *rsa.PrivateKey
57+
serverIntermediateCaKeyPair *rsa.PrivateKey
58+
clientSigningCaKeyPair *rsa.PrivateKey
59+
60+
serverCaCert *x509.Certificate
61+
serverIntermediateCaCert *x509.Certificate
62+
clientSigningCACertificate *x509.Certificate
63+
64+
serverKeyPair *rsa.PrivateKey
65+
serverCert *x509.Certificate
66+
casServerCertificate *x509.Certificate
6767
}
6868

6969
func mustGenerateKey() *rsa.PrivateKey {
@@ -76,55 +76,13 @@ func mustGenerateKey() *rsa.PrivateKey {
7676

7777
// newTLSCertificates creates a new instance of the TLSCertificates.
7878
func newTLSCertificates(projectName, instanceName string, sans []string, clientCertExpires time.Time) *TLSCertificates {
79-
oneYear := time.Now().AddDate(1, 0, 0)
80-
8179
c := &TLSCertificates{
82-
clientCertExpires: clientCertExpires,
83-
ServerCaKeyPair: mustGenerateKey(),
84-
SigningCaKeyPair: mustGenerateKey(),
85-
ServerKeyPair: mustGenerateKey(),
86-
ServerIntermediateCaKeyPair: mustGenerateKey(),
87-
ServerSigningCaKeyPair: mustGenerateKey(),
88-
ClientKeyPair: mustGenerateKey(),
89-
DomainServerKeyPair: mustGenerateKey(),
80+
clientCertExpires: clientCertExpires,
81+
projectName: projectName,
82+
instanceName: instanceName,
83+
sans: sans,
9084
}
91-
92-
c.ServerCaCert = mustBuildRootCertificate(serverCaSubject, c.ServerCaKeyPair)
93-
c.SigningCaCert = mustBuildRootCertificate(signingCaSubject, c.SigningCaKeyPair)
94-
95-
c.ServerCert = mustBuildSignedCertificate(
96-
false,
97-
name(projectName+":"+instanceName),
98-
c.ServerKeyPair,
99-
serverCaSubject,
100-
c.ServerCaKeyPair,
101-
oneYear,
102-
nil)
103-
104-
c.ServerIntermediateCaCert =
105-
mustBuildSignedCertificate(
106-
true,
107-
intermediateCaSubject,
108-
c.ServerIntermediateCaKeyPair,
109-
serverCaSubject,
110-
c.ServerCaKeyPair,
111-
oneYear,
112-
nil)
113-
114-
c.CasServerCertificate =
115-
mustBuildSignedCertificate(
116-
false,
117-
name(""),
118-
c.ServerKeyPair,
119-
intermediateCaSubject,
120-
c.ServerIntermediateCaKeyPair,
121-
oneYear,
122-
sans)
123-
124-
c.CasServerCertificateChain =
125-
[]*x509.Certificate{
126-
c.CasServerCertificate, c.ServerIntermediateCaCert, c.ServerCaCert}
127-
85+
c.RotateCA()
12886
return c
12987
}
13088

@@ -252,7 +210,7 @@ func (ct *TLSCertificates) signWithClientKey(clientKey *rsa.PublicKey) ([]byte,
252210
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
253211
BasicConstraintsValid: true,
254212
}
255-
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ct.SigningCaCert, clientKey, ct.SigningCaKeyPair)
213+
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ct.clientSigningCACertificate, clientKey, ct.clientSigningCaKeyPair)
256214
if err != nil {
257215
return nil, err
258216
}
@@ -273,9 +231,9 @@ func (ct *TLSCertificates) generateServerCertWithCn(cn string) *x509.Certificate
273231
return mustBuildSignedCertificate(
274232
false,
275233
name(cn),
276-
ct.ServerKeyPair,
234+
ct.serverKeyPair,
277235
serverCaSubject,
278-
ct.ServerCaKeyPair,
236+
ct.serverCaKeyPair,
279237
time.Now().Add(1*time.Hour), nil)
280238
}
281239

@@ -286,16 +244,67 @@ func (ct *TLSCertificates) serverChain(serverCAMode string) []tls.Certificate {
286244
// if this server is running in legacy mode
287245
if serverCAMode == "" || serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
288246
return []tls.Certificate{{
289-
Certificate: [][]byte{ct.ServerCert.Raw, ct.ServerCaCert.Raw},
290-
PrivateKey: ct.ServerKeyPair,
291-
Leaf: ct.ServerCert,
247+
Certificate: [][]byte{ct.serverCert.Raw, ct.serverCaCert.Raw},
248+
PrivateKey: ct.serverKeyPair,
249+
Leaf: ct.serverCert,
292250
}}
293251
}
294252

295253
return []tls.Certificate{{
296-
Certificate: [][]byte{ct.CasServerCertificate.Raw, ct.ServerIntermediateCaCert.Raw, ct.ServerCaCert.Raw},
297-
PrivateKey: ct.ServerKeyPair,
298-
Leaf: ct.CasServerCertificate,
254+
Certificate: [][]byte{ct.casServerCertificate.Raw, ct.serverIntermediateCaCert.Raw, ct.serverCaCert.Raw},
255+
PrivateKey: ct.serverKeyPair,
256+
Leaf: ct.casServerCertificate,
299257
}}
300258

301259
}
260+
func (ct *TLSCertificates) ClientCAPool() *x509.CertPool {
261+
clientCa := x509.NewCertPool()
262+
clientCa.AddCert(ct.clientSigningCACertificate)
263+
return clientCa
264+
}
265+
266+
func (ct *TLSCertificates) RotateClientCA() {
267+
ct.clientSigningCaKeyPair = mustGenerateKey()
268+
ct.clientSigningCACertificate = mustBuildRootCertificate(signingCaSubject, ct.clientSigningCaKeyPair)
269+
}
270+
271+
func (ct *TLSCertificates) RotateCA() {
272+
oneYear := time.Now().AddDate(1, 0, 0)
273+
ct.serverCaKeyPair = mustGenerateKey()
274+
ct.clientSigningCaKeyPair = mustGenerateKey()
275+
ct.serverKeyPair = mustGenerateKey()
276+
ct.serverIntermediateCaKeyPair = mustGenerateKey()
277+
278+
ct.serverCaCert = mustBuildRootCertificate(serverCaSubject, ct.serverCaKeyPair)
279+
280+
ct.serverIntermediateCaCert =
281+
mustBuildSignedCertificate(
282+
true,
283+
intermediateCaSubject,
284+
ct.serverIntermediateCaKeyPair,
285+
serverCaSubject,
286+
ct.serverCaKeyPair,
287+
oneYear,
288+
nil)
289+
290+
ct.casServerCertificate =
291+
mustBuildSignedCertificate(
292+
false,
293+
name(""),
294+
ct.serverKeyPair,
295+
intermediateCaSubject,
296+
ct.serverIntermediateCaKeyPair,
297+
oneYear,
298+
ct.sans)
299+
300+
ct.serverCert = mustBuildSignedCertificate(
301+
false,
302+
name(ct.projectName+":"+ct.instanceName),
303+
ct.serverKeyPair,
304+
serverCaSubject,
305+
ct.serverCaKeyPair,
306+
oneYear,
307+
nil)
308+
309+
ct.RotateClientCA()
310+
}

internal/mock/cloudsql.go

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ type FakeCSQLInstance struct {
6161
// Cert is the server's certificate
6262
Cert *x509.Certificate
6363
// certs holds all of the certificates for this instance
64-
certs *TLSCertificates
65-
clientCa *x509.CertPool
64+
certs *TLSCertificates
6665
}
6766

6867
// String returns the instance connection name for the
@@ -75,12 +74,13 @@ func (f FakeCSQLInstance) String() string {
7574
func (f FakeCSQLInstance) serverCACert() ([]byte, error) {
7675
if f.signer != nil {
7776
return f.signer(f.Cert, f.Key)
77+
} else {
78+
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
79+
// legacy server mode, return only the server cert
80+
return toPEMFormat(f.certs.serverCert)
81+
}
82+
return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert)
7883
}
79-
if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" {
80-
// clasic server mode, return cert with CA
81-
return toPEMFormat(f.certs.ServerCert)
82-
}
83-
return toPEMFormat(f.certs.CasServerCertificateChain...)
8484
}
8585

8686
// ClientCert creates an ephemeral client certificate signed with the Cloud SQL
@@ -219,17 +219,10 @@ func NewFakeCSQLInstanceWithSan(project, region, name string, sanDNSNames []stri
219219

220220
certs := newTLSCertificates(project, name, sanDNSNames, f.certExpiry)
221221

222-
f.Key = certs.ServerKeyPair
223-
f.Cert = certs.ServerCert
222+
f.Key = certs.serverKeyPair
223+
f.Cert = certs.serverCert
224224
f.certs = certs
225225

226-
clientCa, err := x509.SystemCertPool()
227-
clientCa.AddCert(certs.SigningCaCert)
228-
if err != nil {
229-
panic(err)
230-
}
231-
f.clientCa = clientCa
232-
233226
return f
234227
}
235228

@@ -263,7 +256,7 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
263256

264257
ln, err := tls.Listen("tcp", ":3307", &tls.Config{
265258
Certificates: i.certs.serverChain(i.serverCAMode),
266-
ClientCAs: i.clientCa,
259+
ClientCAs: i.certs.ClientCAPool(),
267260
ClientAuth: tls.RequireAndVerifyClientCert,
268261
})
269262
if err != nil {
@@ -303,3 +296,11 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() {
303296
_ = ln.Close()
304297
}
305298
}
299+
300+
func RotateCA(inst FakeCSQLInstance) {
301+
inst.certs.RotateCA()
302+
}
303+
304+
func RotateClientCA(inst FakeCSQLInstance) {
305+
inst.certs.RotateClientCA()
306+
}

0 commit comments

Comments
 (0)