diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index c2f7561f..6e437768 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -150,19 +150,6 @@ func TestConnectionInfoTLSConfig(t *testing.T) { t.Fatal(err) } - // Now self sign the server's cert - // TODO: this also should return structured data and handle the PEM - // encoding elsewhere - certBytes, err := mock.SelfSign(i.Cert, i.Key) - if err != nil { - t.Fatal(err) - } - b, _ = pem.Decode(certBytes) - serverCACert, err := x509.ParseCertificate(b.Bytes) - if err != nil { - t.Fatal(err) - } - // Assemble a connection info with the raw and parsed client cert // and the self-signed server certificate ci := ConnectionInfo{ @@ -172,7 +159,7 @@ func TestConnectionInfoTLSConfig(t *testing.T) { PrivateKey: RSAKey, Leaf: clientCert, }, - ServerCACert: []*x509.Certificate{serverCACert}, + ServerCACert: []*x509.Certificate{i.Cert}, DBVersion: "doesn't matter here", Expiration: clientCert.NotAfter, } @@ -198,7 +185,7 @@ func TestConnectionInfoTLSConfig(t *testing.T) { } verifyPeerCert := got.VerifyPeerCertificate - err = verifyPeerCert([][]byte{serverCACert.Raw}, nil) + err = verifyPeerCert([][]byte{i.Cert.Raw}, nil) if err != nil { t.Fatalf("expected to verify peer cert, got error: %v", err) } diff --git a/internal/mock/certs.go b/internal/mock/certs.go new file mode 100644 index 00000000..fa62be61 --- /dev/null +++ b/internal/mock/certs.go @@ -0,0 +1,310 @@ +// Copyright 2025 Google LLC + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// https://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/binary" + "encoding/pem" + "math/big" + "time" +) + +func name(cn string) pkix.Name { + return pkix.Name{ + Country: []string{"US"}, + Organization: []string{"Google\\, Inc"}, + CommonName: cn, + } +} + +// "C=US,O=Google\\, Inc,CN=Google Cloud SQL Root CA" +var serverCaSubject = name("Google Cloud SQL Root CA") +var intermediateCaSubject = name("Google Cloud SQL Intermediate CA") +var signingCaSubject = name("Google Cloud SQL Signing CA foo:baz") +var instanceWithCnSubject = name("myProject:myInstance") + +// TLSCertificates generates an accurate reproduction of the TLS certificates +// used by Cloud SQL. This was translated to Go from the Java connector. +// +// From the cloud-sql-jdbc-socket-factory project: +// core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java +type TLSCertificates struct { + clientCertExpires time.Time + projectName string + instanceName string + sans []string + + serverCaKey *rsa.PrivateKey + serverIntermediateCaKey *rsa.PrivateKey + clientSigningCaKey *rsa.PrivateKey + + serverCaCert *x509.Certificate + serverIntermediateCaCert *x509.Certificate + clientSigningCACertificate *x509.Certificate + + serverKey *rsa.PrivateKey + serverCert *x509.Certificate + casServerCertificate *x509.Certificate +} + +func mustGenerateKey() *rsa.PrivateKey { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + return key +} + +// newTLSCertificates creates a new instance of the TLSCertificates. +func newTLSCertificates(projectName, instanceName string, sans []string, clientCertExpires time.Time) *TLSCertificates { + c := &TLSCertificates{ + clientCertExpires: clientCertExpires, + projectName: projectName, + instanceName: instanceName, + sans: sans, + } + c.rotateCA() + return c +} + +// generateSKI Generate public key id. Certificates need to include +// the key id to make the certificate chain work. +func generateSKI(pub *rsa.PublicKey) []byte { + bs := make([]byte, 8) + binary.LittleEndian.PutUint64(bs, uint64(pub.E)) + + hasher := sha1.New() + hasher.Write(bs) + if pub.N != nil { + hasher.Write(pub.N.Bytes()) + } + ski := hasher.Sum(nil) + + return ski +} + +// mustBuildRootCertificate produces a self-signed certificate. +// or panics - use only for testing. +func mustBuildRootCertificate(subject pkix.Name, k *rsa.PrivateKey) *x509.Certificate { + + sn, err := rand.Int(rand.Reader, big.NewInt(1000)) + if err != nil { + panic(err) + } + + cert := &x509.Certificate{ + SerialNumber: sn, + SubjectKeyId: generateSKI(&k.PublicKey), + Subject: subject, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + certDerBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, &k.PublicKey, k) + c, err := x509.ParseCertificate(certDerBytes) + if err != nil { + panic(err) + } + return c +} + +// mustBuildSignedCertificate produces a certificate for Subject that is signed +// by the issuer. +func mustBuildSignedCertificate( + isCa bool, + subject pkix.Name, + subjectPublicKey *rsa.PrivateKey, + certificateIssuer pkix.Name, + issuerPrivateKey *rsa.PrivateKey, + notAfter time.Time, + subjectAlternativeNames []string) *x509.Certificate { + + sn, err := rand.Int(rand.Reader, big.NewInt(1000)) + if err != nil { + panic(err) + } + + cert := &x509.Certificate{ + SerialNumber: sn, + Subject: subject, + SubjectKeyId: generateSKI(&subjectPublicKey.PublicKey), + AuthorityKeyId: generateSKI(&issuerPrivateKey.PublicKey), + Issuer: certificateIssuer, + NotBefore: time.Now(), + NotAfter: notAfter, + IsCA: isCa, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + DNSNames: subjectAlternativeNames, + } + + certDerBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, &subjectPublicKey.PublicKey, issuerPrivateKey) + c, err := x509.ParseCertificate(certDerBytes) + if err != nil { + panic(err) + } + return c + +} + +// toPEMFormat Converts an array of certificates to PEM format. +func toPEMFormat(certs ...*x509.Certificate) ([]byte, error) { + certPEM := new(bytes.Buffer) + + for _, cert := range certs { + err := pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err != nil { + return nil, err + } + } + + return certPEM.Bytes(), nil +} + +// signWithClientKey produces a PEM encoded certificate client certificate +// containing the clientKey public key, signed by the client CA certificate. +func (ct *TLSCertificates) signWithClientKey(clientKey *rsa.PublicKey) ([]byte, error) { + notAfter := ct.clientCertExpires + if ct.clientCertExpires.IsZero() { + notAfter = time.Now().Add(1 * time.Hour) + } + + // Create a signed cert from the client's public key. + cert := &x509.Certificate{ // TODO: Validate this format vs API + SerialNumber: &big.Int{}, + Subject: pkix.Name{ + Country: []string{"US"}, + Organization: []string{"Google, Inc"}, + CommonName: "Google Cloud SQL Client", + }, + NotBefore: time.Now(), + NotAfter: notAfter, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ct.clientSigningCACertificate, clientKey, ct.clientSigningCaKey) + if err != nil { + return nil, err + } + certPEM := new(bytes.Buffer) + err = pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + if err != nil { + return nil, err + } + return certPEM.Bytes(), nil +} + +// generateServerCertWithCn generates a server certificate for legacy +// GOOGLE_MANAGED_INTERNAL_CA mode where the instance name is in the CN. +func (ct *TLSCertificates) generateServerCertWithCn(cn string) *x509.Certificate { + return mustBuildSignedCertificate( + false, + name(cn), + ct.serverKey, + serverCaSubject, + ct.serverCaKey, + time.Now().Add(1*time.Hour), nil) +} + +// serverChain creates a []tls.Certificate for use with a TLS server socket. +// serverCAMode controls whether this returns a legacy or CAS server +// certificate. +func (ct *TLSCertificates) serverChain(serverCAMode string) []tls.Certificate { + // if this server is running in legacy mode + if serverCAMode == "" || serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" { + return []tls.Certificate{{ + Certificate: [][]byte{ct.serverCert.Raw, ct.serverCaCert.Raw}, + PrivateKey: ct.serverKey, + Leaf: ct.serverCert, + }} + } + + return []tls.Certificate{{ + Certificate: [][]byte{ct.casServerCertificate.Raw, ct.serverIntermediateCaCert.Raw, ct.serverCaCert.Raw}, + PrivateKey: ct.serverKey, + Leaf: ct.casServerCertificate, + }} + +} +func (ct *TLSCertificates) clientCAPool() *x509.CertPool { + clientCa := x509.NewCertPool() + clientCa.AddCert(ct.clientSigningCACertificate) + return clientCa +} + +func (ct *TLSCertificates) rotateClientCA() { + ct.clientSigningCaKey = mustGenerateKey() + ct.clientSigningCACertificate = mustBuildRootCertificate(signingCaSubject, ct.clientSigningCaKey) +} + +func (ct *TLSCertificates) rotateCA() { + oneYear := time.Now().AddDate(1, 0, 0) + ct.serverCaKey = mustGenerateKey() + ct.clientSigningCaKey = mustGenerateKey() + ct.serverKey = mustGenerateKey() + ct.serverIntermediateCaKey = mustGenerateKey() + + ct.serverCaCert = mustBuildRootCertificate(serverCaSubject, ct.serverCaKey) + + ct.serverIntermediateCaCert = + mustBuildSignedCertificate( + true, + intermediateCaSubject, + ct.serverIntermediateCaKey, + serverCaSubject, + ct.serverCaKey, + oneYear, + nil) + + ct.casServerCertificate = + mustBuildSignedCertificate( + false, + name(""), + ct.serverKey, + intermediateCaSubject, + ct.serverIntermediateCaKey, + oneYear, + ct.sans) + + ct.serverCert = mustBuildSignedCertificate( + false, + name(ct.projectName+":"+ct.instanceName), + ct.serverKey, + serverCaSubject, + ct.serverCaKey, + oneYear, + nil) + + ct.rotateClientCA() +} diff --git a/internal/mock/cloudsql.go b/internal/mock/cloudsql.go index f871cf28..896b35fa 100644 --- a/internal/mock/cloudsql.go +++ b/internal/mock/cloudsql.go @@ -21,10 +21,10 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" "encoding/pem" "fmt" - "math/big" + "io" + "net" "testing" "time" @@ -55,10 +55,13 @@ type FakeCSQLInstance struct { pscEnabled bool signer SignFunc clientSigner ClientSignFunc + certExpiry time.Time // Key is the server's private key Key *rsa.PrivateKey // Cert is the server's certificate Cert *x509.Certificate + // certs holds all of the certificates for this instance + certs *TLSCertificates } // String returns the instance connection name for the @@ -67,14 +70,29 @@ func (f FakeCSQLInstance) String() string { return fmt.Sprintf("%v:%v:%v", f.project, f.region, f.name) } -func (f FakeCSQLInstance) signedCert() ([]byte, error) { - return f.signer(f.Cert, f.Key) +// serverCACert returns the current server CA cert. +func (f FakeCSQLInstance) serverCACert() ([]byte, error) { + if f.signer != nil { + return f.signer(f.Cert, f.Key) + } + if f.serverCAMode == "" || f.serverCAMode == "GOOGLE_MANAGED_INTERNAL_CA" { + // legacy server mode, return only the server cert + return toPEMFormat(f.certs.serverCert) + } + return toPEMFormat(f.certs.casServerCertificate, f.certs.serverIntermediateCaCert, f.certs.serverCaCert) } // ClientCert creates an ephemeral client certificate signed with the Cloud SQL // instance's private key. The return value is PEM encoded. func (f FakeCSQLInstance) ClientCert(pubKey *rsa.PublicKey) ([]byte, error) { - return f.clientSigner(f.Cert, f.Key, pubKey) + if f.clientSigner != nil { + c, err := f.clientSigner(f.Cert, f.Key, pubKey) + if err != nil { + return c, err + } + return c, nil + } + return f.certs.signWithClientKey(pubKey) } // FakeCSQLInstanceOption is a function that configures a FakeCSQLInstance. @@ -111,7 +129,7 @@ func WithDNS(dns string) FakeCSQLInstanceOption { // WithCertExpiry sets the server certificate's expiration to t. func WithCertExpiry(t time.Time) FakeCSQLInstanceOption { return func(f *FakeCSQLInstance) { - f.Cert.NotAfter = t + f.certExpiry = t } } @@ -184,28 +202,26 @@ func NewFakeCSQLInstance(project, region, name string, opts ...FakeCSQLInstanceO // NewFakeCSQLInstanceWithSan returns a CloudSQLInst object for configuring // mocks, including SubjectAlternativeNames in the server certificate. func NewFakeCSQLInstanceWithSan(project, region, name string, sanDNSNames []string, opts ...FakeCSQLInstanceOption) FakeCSQLInstance { - // TODO: consider options for this? - key, cert, err := generateCerts(project, name, sanDNSNames) - if err != nil { - panic(err) - } f := FakeCSQLInstance{ - project: project, - region: region, - name: name, - ipAddrs: map[string]string{"PUBLIC": "0.0.0.0"}, - DNSName: "", - dbVersion: "POSTGRES_12", // default of no particular importance - backendType: "SECOND_GEN", - signer: SelfSign, - clientSigner: SignWithClientKey, - Key: key, - Cert: cert, + project: project, + region: region, + name: name, + ipAddrs: map[string]string{"PUBLIC": "0.0.0.0"}, + DNSName: "", + dbVersion: "POSTGRES_12", // default of no particular importance + backendType: "SECOND_GEN", } for _, o := range opts { o(&f) } + + certs := newTLSCertificates(project, name, sanDNSNames, f.certExpiry) + + f.Key = certs.serverKey + f.Cert = certs.serverCert + f.certs = certs + return f } @@ -226,114 +242,21 @@ func SelfSign(c *x509.Certificate, k *rsa.PrivateKey) ([]byte, error) { return certPEM.Bytes(), nil } -// SignWithClientKey produces a PEM encoded certificate signed by the parent -// certificate c using the server's private key and the client's public key. -func SignWithClientKey(c *x509.Certificate, k *rsa.PrivateKey, clientKey *rsa.PublicKey) ([]byte, error) { - // Create a signed cert from the client's public key. - cert := &x509.Certificate{ // TODO: Validate this format vs API - SerialNumber: &big.Int{}, - Subject: pkix.Name{ - Country: []string{"US"}, - Organization: []string{"Google, Inc"}, - CommonName: "Google Cloud SQL Client", - }, - NotBefore: time.Now(), - NotAfter: c.NotAfter, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - certBytes, err := x509.CreateCertificate(rand.Reader, cert, c, clientKey, k) - if err != nil { - return nil, err - } - certPEM := new(bytes.Buffer) - err = pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - if err != nil { - return nil, err - } - return certPEM.Bytes(), nil -} - // GenerateCertWithCommonName produces a certificate signed by the Fake Cloud // SQL instance's CA with the specified common name cn. func GenerateCertWithCommonName(i FakeCSQLInstance, cn string) []byte { - cert := &x509.Certificate{ - SerialNumber: &big.Int{}, - Subject: pkix.Name{ - CommonName: cn, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - IsCA: true, - } - signed, err := x509.CreateCertificate( - rand.Reader, cert, i.Cert, &i.Key.PublicKey, i.Key) - if err != nil { - panic(err) - } - return signed -} - -// generateCerts generates a private key, an X.509 certificate, and a TLS -// certificate for a particular fake Cloud SQL database instance. -func generateCerts(project, name string, dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - cert := &x509.Certificate{ - SerialNumber: &big.Int{}, - Subject: pkix.Name{ - CommonName: fmt.Sprintf("%s:%s", project, name), - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - DNSNames: dnsNames, - } - - return key, cert, nil + return i.certs.generateServerCertWithCn(cn).Raw } // StartServerProxy starts a fake server proxy and listens on the provided port // on all interfaces, configured with TLS as specified by the FakeCSQLInstance. // Callers should invoke the returned function to clean up all resources. func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() { - certBytes, err := x509.CreateCertificate( - rand.Reader, i.Cert, i.Cert, &i.Key.PublicKey, i.Key) - if err != nil { - t.Fatalf("failed to create certificate: %v", err) - } - caPEM := &bytes.Buffer{} - err = pem.Encode(caPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) - if err != nil { - t.Fatalf("pem.Encode: %v", err) - } - - caKeyPEM := &bytes.Buffer{} - err = pem.Encode(caKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(i.Key), - }) - if err != nil { - t.Fatalf("pem.Encode: %v", err) - } - - serverCert, err := tls.X509KeyPair(caPEM.Bytes(), caKeyPEM.Bytes()) - if err != nil { - t.Fatalf("failed to create X.509 Key Pair: %v", err) - } ln, err := tls.Listen("tcp", ":3307", &tls.Config{ - Certificates: []tls.Certificate{serverCert}, + Certificates: i.certs.serverChain(i.serverCAMode), + ClientCAs: i.certs.clientCAPool(), + ClientAuth: tls.RequireAndVerifyClientCert, }) if err != nil { t.Fatalf("failed to start listener: %v", err) @@ -345,11 +268,24 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() { case <-ctx.Done(): return default: - conn, err := ln.Accept() - if err != nil { + conn, aErr := ln.Accept() + if opErr, ok := aErr.(net.Error); ok { + if opErr.Timeout() { + continue + } return } - _, _ = conn.Write([]byte(i.name)) + if aErr == io.EOF { + return + } + if aErr != nil { + t.Logf("Fake server accept error: %v", aErr) + return + } + _, wErr := conn.Write([]byte(i.name)) + if wErr != nil { + t.Logf("Fake server write error: %v", wErr) + } _ = conn.Close() } } @@ -359,3 +295,13 @@ func StartServerProxy(t *testing.T, i FakeCSQLInstance) func() { _ = ln.Close() } } + +// RotateCA rotates all CA certificates and keys. +func RotateCA(inst FakeCSQLInstance) { + inst.certs.rotateCA() +} + +// RotateClientCA rotates only client CA certificates and keys. +func RotateClientCA(inst FakeCSQLInstance) { + inst.certs.rotateClientCA() +} diff --git a/internal/mock/sqladmin.go b/internal/mock/sqladmin.go index b25f042f..f8afdfd7 100644 --- a/internal/mock/sqladmin.go +++ b/internal/mock/sqladmin.go @@ -112,19 +112,12 @@ func InstanceGetSuccess(i FakeCSQLInstance, ct int) *Request { ips = append(ips, &sqladmin.IpMapping{IpAddress: addr, Type: "PRIVATE"}) } } - certBytes1, err := i.signedCert() + + certBytes, err := i.serverCACert() if err != nil { panic(err) } - certBytes := certBytes1 - if i.serverCAMode == "GOOGLE_MANAGED_CAS_CA" { - // CAS instances return two CAs in the trust chain. - certBytes2, err := i.signedCert() - if err != nil { - panic(err) - } - certBytes = append(certBytes, certBytes2...) - } + db := &sqladmin.ConnectSettings{ BackendType: i.backendType, DatabaseVersion: i.dbVersion,