Skip to content

Commit d669d38

Browse files
committed
filtering TLS connections based on the subject name from Caller
1 parent 73f6733 commit d669d38

File tree

7 files changed

+104
-47
lines changed

7 files changed

+104
-47
lines changed

cns/configuration/cns_config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,6 @@
3535
"AZRSettings": {
3636
"PopulateHomeAzCacheRetryIntervalSecs": 60
3737
},
38-
"MinTLSVersion": "TLS 1.2"
38+
"MinTLSVersion": "TLS 1.2",
39+
"AllowedClientSubjectName": ""
3940
}

cns/configuration/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ type CNSConfig struct {
5959
WireserverIP string
6060
GRPCSettings GRPCSettings
6161
MinTLSVersion string
62+
AllowedClientSubjectName string
6263
}
6364

6465
type TelemetrySettings struct {

cns/configuration/configuration_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
222222
IPAddress: "localhost",
223223
Port: 8080,
224224
},
225-
MinTLSVersion: "TLS 1.2",
225+
MinTLSVersion: "TLS 1.2",
226+
AllowedClientSubjectName: "",
226227
},
227228
},
228229
{
@@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
253254
IPAddress: "192.168.1.1",
254255
Port: 9090,
255256
},
256-
MinTLSVersion: "TLS 1.3",
257+
MinTLSVersion: "TLS 1.3",
258+
AllowedClientSubjectName: "example.com",
257259
},
258260
want: CNSConfig{
259261
ChannelMode: "Other",
@@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) {
283285
IPAddress: "192.168.1.1",
284286
Port: 9090,
285287
},
286-
MinTLSVersion: "TLS 1.3",
288+
MinTLSVersion: "TLS 1.3",
289+
AllowedClientSubjectName: "example.com",
287290
},
288291
},
289292
}

cns/service.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,25 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.
156156
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
157157
}
158158

159+
// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name.
160+
func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error {
161+
// no client subject name provided, skip verification
162+
if clientSubjectName == "" {
163+
return nil
164+
}
165+
166+
cert, err := x509.ParseCertificate(rawCerts[0])
167+
if err != nil {
168+
return errors.Errorf("failed to parse certificate: %v", err)
169+
}
170+
171+
err = cert.VerifyHostname(clientSubjectName)
172+
if err != nil {
173+
return errors.Errorf("failed to verify client certificate hostname: %v", err)
174+
}
175+
return nil
176+
}
177+
159178
func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
160179
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
161180
if err != nil {
@@ -202,8 +221,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
202221
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
203222
tlsConfig.ClientCAs = rootCAs
204223
tlsConfig.RootCAs = rootCAs
224+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
225+
return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName)
226+
}
205227
}
206-
207228
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
208229

209230
return tlsConfig, nil
@@ -254,6 +275,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
254275
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
255276
tlsConfig.ClientCAs = rootCAs
256277
tlsConfig.RootCAs = rootCAs
278+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
279+
return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName)
280+
}
257281
}
258282

259283
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)

cns/service/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ func main() {
810810
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
811811
UseMTLS: cnsconfig.UseMTLS,
812812
MinTLSVersion: cnsconfig.MinTLSVersion,
813+
AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName,
813814
}
814815
}
815816

cns/service_test.go

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"crypto/x509"
1313
"crypto/x509/pkix"
1414
"encoding/pem"
15+
"fmt"
1516
"math/big"
1617
"net/http"
1718
"os"
@@ -133,57 +134,82 @@ func TestNewService(t *testing.T) {
133134
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
134135
testCertFilePath := createTestCertificate(t)
135136

136-
config.TLSSettings = serverTLS.TlsSettings{
137-
TLSPort: "10091",
138-
TLSSubjectName: "localhost",
139-
TLSCertificatePath: testCertFilePath,
140-
UseMTLS: true,
141-
MinTLSVersion: "TLS 1.2",
137+
TLSSetting := serverTLS.TlsSettings{
138+
TLSPort: "10091",
139+
TLSSubjectName: "localhost",
140+
TLSCertificatePath: testCertFilePath,
141+
UseMTLS: true,
142+
MinTLSVersion: "TLS 1.2",
143+
AllowedClientSubjectName: "example.com",
142144
}
143145

144-
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
145-
require.NoError(t, err)
146-
require.IsType(t, &Service{}, svc)
146+
TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{
147+
TLSPort: "10092",
148+
TLSSubjectName: "localhost",
149+
TLSCertificatePath: testCertFilePath,
150+
UseMTLS: true,
151+
MinTLSVersion: "TLS 1.2",
152+
AllowedClientSubjectName: "random.com",
153+
}
147154

148-
svc.SetOption(acn.OptCnsURL, "")
149-
svc.SetOption(acn.OptCnsPort, "")
155+
runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) {
156+
config.TLSSettings = tlsSettings
157+
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
158+
require.NoError(t, err)
159+
require.IsType(t, &Service{}, svc)
150160

151-
err = svc.Initialize(config)
152-
t.Cleanup(func() {
153-
svc.Uninitialize()
154-
})
155-
require.NoError(t, err)
161+
svc.SetOption(acn.OptCnsURL, "")
162+
svc.SetOption(acn.OptCnsPort, "")
156163

157-
err = svc.StartListener(config)
158-
require.NoError(t, err)
164+
err = svc.Initialize(config)
165+
require.NoError(t, err)
159166

160-
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
161-
require.NoError(t, err)
167+
err = svc.StartListener(config)
168+
require.NoError(t, err)
162169

163-
client := &http.Client{
164-
Transport: &http.Transport{
165-
TLSClientConfig: mTLSConfig,
166-
},
167-
}
170+
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
171+
require.NoError(t, err)
168172

169-
// TLS listener
170-
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
171-
require.NoError(t, err)
172-
resp, err := client.Do(req)
173-
t.Cleanup(func() {
174-
resp.Body.Close()
175-
})
176-
require.NoError(t, err)
173+
client := &http.Client{
174+
Transport: &http.Transport{
175+
TLSClientConfig: mTLSConfig,
176+
},
177+
}
177178

178-
// HTTP listener
179-
httpClient := &http.Client{}
180-
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
181-
require.NoError(t, err)
182-
resp, err = httpClient.Do(req)
183-
t.Cleanup(func() {
184-
resp.Body.Close()
185-
})
186-
require.NoError(t, err)
179+
tlsUrl := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort)
180+
// TLS listener
181+
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsUrl, http.NoBody)
182+
require.NoError(t, err)
183+
resp, err := client.Do(req)
184+
t.Cleanup(func() {
185+
if resp != nil && resp.Body != nil {
186+
resp.Body.Close()
187+
}
188+
})
189+
if handshakeFailureExpected {
190+
require.Error(t, err)
191+
require.ErrorContains(t, err, "failed to verify client certificate hostname")
192+
193+
} else {
194+
require.NoError(t, err)
195+
}
196+
197+
// HTTP listener
198+
httpClient := &http.Client{}
199+
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
200+
require.NoError(t, err)
201+
resp, err = httpClient.Do(req)
202+
t.Cleanup(func() {
203+
resp.Body.Close()
204+
})
205+
require.NoError(t, err)
206+
207+
// Cleanup
208+
svc.Uninitialize()
209+
210+
}
211+
runMutualTLSTest(TLSSetting, false)
212+
runMutualTLSTest(TLSSettingWithDisallowedClientSN, true)
187213
})
188214
}
189215

server/tls/tlscertificate_retriever.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ type TlsSettings struct {
1515
KeyVaultCertificateRefreshInterval time.Duration
1616
UseMTLS bool
1717
MinTLSVersion string
18+
AllowedClientSubjectName string
1819
}
1920

2021
func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {

0 commit comments

Comments
 (0)