Skip to content

Commit 3ced6ae

Browse files
authored
security, config: add a config cert-allowed-cn to the HTTP API (#736)
1 parent 3fba696 commit 3ced6ae

File tree

7 files changed

+172
-138
lines changed

7 files changed

+172
-138
lines changed

conf/proxy.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ graceful-close-conn-timeout = 15
107107
[security.server-http-tls]
108108
# proxy HTTP port will use this
109109
# auto-certs = true
110+
# cert-allowed-cn = ["tiproxy", "tidb", "test-client", "prometheus"]
110111

111112
# require-backend-tls = false
112113

lib/config/security.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
package config
55

66
type TLSConfig struct {
7-
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
8-
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
9-
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
10-
MinTLSVersion string `yaml:"min-tls-version,omitempty" toml:"min-tls-version,omitempty" json:"min-tls-version,omitempty"`
11-
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
12-
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
13-
AutoExpireDuration string `yaml:"autocert-expire-duration,omitempty" toml:"autocert-expire-duration,omitempty" json:"autocert-expire-duration,omitempty"`
14-
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
7+
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
8+
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
9+
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
10+
MinTLSVersion string `yaml:"min-tls-version,omitempty" toml:"min-tls-version,omitempty" json:"min-tls-version,omitempty"`
11+
CertAllowedCN []string `yaml:"cert-allowed-cn,omitempty" toml:"cert-allowed-cn,omitempty" json:"cert-allowed-cn,omitempty"`
12+
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
13+
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
14+
AutoExpireDuration string `yaml:"autocert-expire-duration,omitempty" toml:"autocert-expire-duration,omitempty" json:"autocert-expire-duration,omitempty"`
15+
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
1516
}
1617

1718
func (c TLSConfig) HasCert() bool {

lib/util/security/cert.go

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ package security
66
import (
77
"crypto/tls"
88
"crypto/x509"
9-
"encoding/pem"
109
"os"
10+
"strings"
1111
"sync/atomic"
1212
"time"
1313

@@ -70,7 +70,7 @@ func (ci *CertInfo) getClientCert(*tls.CertificateRequestInfo) (*tls.Certificate
7070
return cert, nil
7171
}
7272

73-
func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
73+
func (ci *CertInfo) verifyCA(rawCerts [][]byte) error {
7474
if len(rawCerts) == 0 {
7575
return nil
7676
}
@@ -108,26 +108,23 @@ func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certifi
108108
return err
109109
}
110110

111-
func (ci *CertInfo) loadCA(pemCerts []byte) (*x509.CertPool, error) {
112-
pool := x509.NewCertPool()
113-
for len(pemCerts) > 0 {
114-
var block *pem.Block
115-
block, pemCerts = pem.Decode(pemCerts)
116-
if block == nil {
117-
break
118-
}
119-
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
120-
continue
121-
}
122-
123-
certBytes := block.Bytes
124-
cert, err := x509.ParseCertificate(certBytes)
125-
if err != nil {
126-
continue
111+
func verifyCommonName(allowedCN []string, verifiedChains [][]*x509.Certificate) error {
112+
if len(allowedCN) == 0 {
113+
return nil
114+
}
115+
checkCN := make(map[string]struct{})
116+
for _, cn := range allowedCN {
117+
cn = strings.TrimSpace(cn)
118+
checkCN[cn] = struct{}{}
119+
}
120+
for _, chain := range verifiedChains {
121+
if len(chain) != 0 {
122+
if _, match := checkCN[chain[0].Subject.CommonName]; match {
123+
return nil
124+
}
127125
}
128-
pool.AddCert(cert)
129126
}
130-
return pool, nil
127+
return errors.Errorf("peer certificate authentication failed. The Common Name from the peer certificate was not found in the configuration cert-allowed-cn with value: %v", allowedCN)
131128
}
132129

133130
func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
@@ -144,10 +141,15 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
144141
}
145142

146143
tcfg := &tls.Config{
147-
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
148-
GetCertificate: ci.getCert,
149-
GetClientCertificate: ci.getClientCert,
150-
VerifyPeerCertificate: ci.verifyPeerCertificate,
144+
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
145+
GetCertificate: ci.getCert,
146+
GetClientCertificate: ci.getClientCert,
147+
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
148+
if err := ci.verifyCA(rawCerts); err != nil {
149+
return err
150+
}
151+
return verifyCommonName(cfg.CertAllowedCN, verifiedChains)
152+
},
151153
}
152154

153155
var certPEM, keyPEM []byte
@@ -160,7 +162,7 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
160162
dur = DefaultCertExpiration
161163
}
162164
ci.autoCertExp.Store(now.Add(DefaultCertExpiration - recreateAutoCertAdvance).Unix())
163-
certPEM, keyPEM, _, err = createTempTLS(cfg.RSAKeySize, dur)
165+
certPEM, keyPEM, _, err = createTempTLS(cfg.RSAKeySize, dur, "")
164166
if err != nil {
165167
return nil, err
166168
}
@@ -193,14 +195,20 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
193195
if err != nil {
194196
return nil, err
195197
}
196-
197-
cas, err := ci.loadCA(caPEM)
198-
if err != nil {
199-
return nil, errors.WithStack(err)
198+
certPool := x509.NewCertPool()
199+
if !certPool.AppendCertsFromPEM(caPEM) {
200+
return nil, errors.New("failed to append ca certs")
200201
}
201-
ci.ca.Store(cas)
202+
ci.ca.Store(certPool)
202203

203-
if cfg.SkipCA {
204+
// RequireAndVerifyClientCert requires ClientCAs to verify client certificates.
205+
// But the problem is, the ClientCAs in the returned tls.Config can't be updated after reload,
206+
// which results in connection failure after CA rotation and cert-allowed-cn is set.
207+
tcfg.ClientCAs = certPool
208+
209+
if len(cfg.CertAllowedCN) > 0 {
210+
tcfg.ClientAuth = tls.RequireAndVerifyClientCert
211+
} else if cfg.SkipCA {
204212
tcfg.ClientAuth = tls.RequestClientCert
205213
} else {
206214
tcfg.ClientAuth = tls.RequireAnyClientCert
@@ -213,7 +221,7 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
213221
lg = lg.With(zap.String("tls", "client"), zap.Any("cfg", ci.cfg.Load()))
214222
cfg := ci.cfg.Load()
215223
if cfg.AutoCerts {
216-
lg.Info("specified auto-certs in a client tls config, ignored")
224+
lg.Warn("specified auto-certs in a client tls config, ignored")
217225
}
218226

219227
if !cfg.HasCA() {
@@ -229,22 +237,25 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
229237
}
230238

231239
tcfg := &tls.Config{
232-
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
233-
GetCertificate: ci.getCert,
234-
GetClientCertificate: ci.getClientCert,
235-
InsecureSkipVerify: true,
236-
VerifyPeerCertificate: ci.verifyPeerCertificate,
240+
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
241+
GetCertificate: ci.getCert,
242+
GetClientCertificate: ci.getClientCert,
243+
InsecureSkipVerify: true,
244+
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
245+
return ci.verifyCA(rawCerts)
246+
},
237247
}
238248

239-
certBytes, err := os.ReadFile(cfg.CA)
249+
caPEM, err := os.ReadFile(cfg.CA)
240250
if err != nil {
241-
return nil, errors.WithStack(err)
251+
return nil, err
242252
}
243-
cas, err := ci.loadCA(certBytes)
244-
if err != nil {
245-
return nil, errors.WithStack(err)
253+
certPool := x509.NewCertPool()
254+
if !certPool.AppendCertsFromPEM(caPEM) {
255+
return nil, errors.New("failed to append ca certs")
246256
}
247-
ci.ca.Store(cas)
257+
ci.ca.Store(certPool)
258+
tcfg.RootCAs = certPool
248259

249260
if !cfg.HasCert() {
250261
lg.Info("no certificates, server may reject the connection")

0 commit comments

Comments
 (0)