Skip to content
3 changes: 2 additions & 1 deletion cns/configuration/cns_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@
"AZRSettings": {
"PopulateHomeAzCacheRetryIntervalSecs": 60
},
"MinTLSVersion": "TLS 1.2"
"MinTLSVersion": "TLS 1.2",
"MtlsClientCertSubjectName": ""
}
1 change: 1 addition & 0 deletions cns/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type CNSConfig struct {
WireserverIP string
GRPCSettings GRPCSettings
MinTLSVersion string
MtlsClientCertSubjectName string
}

type TelemetrySettings struct {
Expand Down
9 changes: 6 additions & 3 deletions cns/configuration/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "localhost",
Port: 8080,
},
MinTLSVersion: "TLS 1.2",
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "",
},
},
{
Expand Down Expand Up @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
MtlsClientCertSubjectName: "example.com",
},
want: CNSConfig{
ChannelMode: "Other",
Expand Down Expand Up @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
IPAddress: "192.168.1.1",
Port: 9090,
},
MinTLSVersion: "TLS 1.3",
MinTLSVersion: "TLS 1.3",
MtlsClientCertSubjectName: "example.com",
},
},
}
Expand Down
36 changes: 35 additions & 1 deletion cns/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,35 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
}

// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name.
func verifyPeerCertificate(verifiedChains [][]*x509.Certificate, clientSubjectName string) error {
// no client subject name provided, skip verification
if clientSubjectName == "" {
return nil
}

if len(verifiedChains) == 0 || len(verifiedChains[0]) == 0 {
return errors.New("no client certificate provided during mTLS")
}

clientCert := verifiedChains[0][0]
// Match DNS names (case-insensitive)
dnsName := clientCert.DNSNames
for _, dns := range dnsName {
if strings.EqualFold(dns, clientSubjectName) {
return nil
}
}

// If SANs didn't match, fall back to Common Name (CN) match.
clientCN := clientCert.Subject.CommonName
if clientCert.Subject.CommonName != "" && strings.EqualFold(clientCN, clientSubjectName) {
return nil
}
return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, CN: %s",
clientSubjectName, dnsName, clientCN)
}

func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
if err != nil {
Expand Down Expand Up @@ -202,8 +231,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
}
}

logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)

return tlsConfig, nil
Expand Down Expand Up @@ -254,6 +285,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificate(verifiedChains, tlsSettings.MtlsClientCertSubjectName)
}
}

logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
Expand Down
1 change: 1 addition & 0 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ func main() {
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
UseMTLS: cnsconfig.UseMTLS,
MinTLSVersion: cnsconfig.MinTLSVersion,
MtlsClientCertSubjectName: cnsconfig.MtlsClientCertSubjectName,
}
}

Expand Down
118 changes: 76 additions & 42 deletions cns/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,57 +133,91 @@ func TestNewService(t *testing.T) {
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
testCertFilePath := createTestCertificate(t)

config.TLSSettings = serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
TLSSetting := serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "example.com",
}

svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{
TLSPort: "10092",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "random.com",
}

svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
TLSSettingWithClientCertCN := serverTLS.TlsSettings{
TLSPort: "10093",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
MinTLSVersion: "TLS 1.2",
MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate
}

err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) {
config.TLSSettings = tlsSettings
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)

err = svc.StartListener(config)
require.NoError(t, err)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")

mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)
err = svc.Initialize(config)
require.NoError(t, err)

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}
err = svc.StartListener(config)
require.NoError(t, err)

// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}

tlsURL := "https://localhost:" + tlsSettings.TLSPort
// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
})
if handshakeFailureExpected {
require.Error(t, err)
require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS")

} else {
require.NoError(t, err)
}

// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)

// Cleanup
svc.Uninitialize()
}
runMutualTLSTest(TLSSetting, false)
runMutualTLSTest(TLSSettingWithClientCertCN, false)
runMutualTLSTest(TLSSettingWithDisallowedClientSN, true)
})
}

Expand Down
1 change: 1 addition & 0 deletions server/tls/tlscertificate_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type TlsSettings struct {
KeyVaultCertificateRefreshInterval time.Duration
UseMTLS bool
MinTLSVersion string
MtlsClientCertSubjectName string
}

func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {
Expand Down
Loading