Skip to content

Commit de225e4

Browse files
authored
feat: add mTLS to CNS (#2751)
* feat: add UseMTLS config * feat: add mTLS auth for CNS * test: add testdata for mTLS tests * chore: add logs on TLS config retrieval * lint: in tests * refactor: use CNS logger, not ACN logger * refactor: add guards to mtlsRootCAsFromCertificate and unit tests * lint: fix lint errors * test: include HTTP listener tests for when TLS/mTLS is enabled * chore: add log for stopping the TLS listener * test: add test helper to create certificates for testing instead of using hardcoded pem file * test: assert non-TLS service has no TLSSettings * test: refactor TestMtlsRootCAsFromCertificate to table-based tests * refactor: pull listener addresses from listener and remove redundant struct field for tls address
1 parent 13f7037 commit de225e4

File tree

9 files changed

+387
-1
lines changed

9 files changed

+387
-1
lines changed

cns/configuration/cns_config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"TLSPort": "10091",
2121
"TLSSubjectName": "",
2222
"UseHTTPS": false,
23+
"UseMTLS": false,
2324
"WireserverIP": "168.63.129.16",
2425
"KeyVaultSettings": {
2526
"URL": "",

cns/configuration/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ type CNSConfig struct {
5555
TLSSubjectName string
5656
TelemetrySettings TelemetrySettings
5757
UseHTTPS bool
58+
UseMTLS bool
5859
WatchPods bool `json:"-"`
5960
WireserverIP string
6061
}

cns/configuration/configuration_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ func TestReadConfigFromFile(t *testing.T) {
8787
PopulateHomeAzCacheRetryIntervalSecs: 60,
8888
},
8989
UseHTTPS: true,
90+
UseMTLS: true,
9091
WireserverIP: "168.63.129.16",
9192
},
9293
wantErr: false,

cns/configuration/testdata/good.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"TelemetryBatchSizeBytes": 16384
3131
},
3232
"UseHTTPS": true,
33+
"UseMTLS": true,
3334
"WireserverIP": "168.63.129.16",
3435
"AZRSettings": {
3536
"PopulateHomeAzCacheRetryIntervalSecs": 60

cns/service.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package cns
66
import (
77
"context"
88
"crypto/tls"
9+
"crypto/x509"
910
"fmt"
1011
"net"
1112
"net/http"
@@ -190,6 +191,18 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
190191
},
191192
}
192193

194+
if tlsSettings.UseMTLS {
195+
rootCAs, err := mtlsRootCAsFromCertificate(&tlsCert)
196+
if err != nil {
197+
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
198+
}
199+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
200+
tlsConfig.ClientCAs = rootCAs
201+
tlsConfig.RootCAs = rootCAs
202+
}
203+
204+
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
205+
193206
return tlsConfig, nil
194207
}
195208

@@ -224,9 +237,51 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
224237
},
225238
}
226239

240+
if tlsSettings.UseMTLS {
241+
tlsCert := cr.GetCertificate()
242+
rootCAs, err := mtlsRootCAsFromCertificate(tlsCert)
243+
if err != nil {
244+
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
245+
}
246+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
247+
tlsConfig.ClientCAs = rootCAs
248+
tlsConfig.RootCAs = rootCAs
249+
}
250+
251+
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
252+
227253
return &tlsConfig, nil
228254
}
229255

256+
// Given a TLS cert, return the root CAs
257+
func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) {
258+
switch {
259+
case tlsCert == nil || len(tlsCert.Certificate) == 0:
260+
return nil, errors.New("no certificate provided")
261+
case len(tlsCert.Certificate) == 1:
262+
certs := x509.NewCertPool()
263+
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
264+
if err != nil {
265+
return nil, errors.Wrap(err, "parsing self signed cert")
266+
}
267+
certs.AddCert(cert)
268+
269+
return certs, nil
270+
default:
271+
certs := x509.NewCertPool()
272+
// given a fullchain cert, we skip leaf cert at index 0 because
273+
// we only want intermediate and root certs in the cert pool for mTLS
274+
for _, certBytes := range tlsCert.Certificate[1:] {
275+
cert, err := x509.ParseCertificate(certBytes)
276+
if err != nil {
277+
return nil, errors.Wrap(err, "parsing root certs")
278+
}
279+
certs.AddCert(cert)
280+
}
281+
return certs, nil
282+
}
283+
}
284+
230285
func (service *Service) StartListener(config *common.ServiceConfig) error {
231286
log.Debugf("[Azure CNS] Going to start listener: %+v", config)
232287

cns/service/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ func main() {
786786
KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName,
787787
MSIResourceID: cnsconfig.MSISettings.ResourceID,
788788
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
789+
UseMTLS: cnsconfig.UseMTLS,
789790
}
790791
}
791792

0 commit comments

Comments
 (0)