Skip to content

Commit 33e3b69

Browse files
committed
address comment
1 parent d3c6255 commit 33e3b69

File tree

2 files changed

+135
-75
lines changed

2 files changed

+135
-75
lines changed

cns/service.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,22 +167,40 @@ func verifyPeerCertificate(verifiedChains [][]*x509.Certificate, clientSubjectNa
167167
return errors.New("no client certificate provided during mTLS")
168168
}
169169

170+
// Get client leaf certificate
170171
clientCert := verifiedChains[0][0]
171172
// Match DNS names (case-insensitive)
172-
dnsName := clientCert.DNSNames
173-
for _, dns := range dnsName {
173+
dnsNames := clientCert.DNSNames
174+
for _, dns := range dnsNames {
174175
if strings.EqualFold(dns, clientSubjectName) {
175176
return nil
176177
}
177178
}
178179

179180
// If SANs didn't match, fall back to Common Name (CN) match.
180181
clientCN := clientCert.Subject.CommonName
181-
if clientCert.Subject.CommonName != "" && strings.EqualFold(clientCN, clientSubjectName) {
182+
if clientCN != "" && strings.EqualFold(clientCN, clientSubjectName) {
182183
return nil
183184
}
184-
return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, CN: %s",
185-
clientSubjectName, dnsName, clientCN)
185+
186+
// maskHalf of the DNS names
187+
for i, dns := range dnsNames {
188+
dnsNames[i] = maskHalf(dns)
189+
}
190+
191+
return errors.Errorf("Failed to verify client certificate subject name during mTLS, clientSubjectName: %s, client cert SANs: %+v, clientCN: %s",
192+
clientSubjectName, dnsNames, maskHalf(clientCN))
193+
}
194+
195+
// maskHalf masks half of the input string with asterisks.
196+
func maskHalf(s string) string {
197+
n := len(s)
198+
if n == 0 {
199+
return s
200+
}
201+
202+
half := n / 2
203+
return s[:half] + strings.Repeat("*", n-half)
186204
}
187205

188206
func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {

cns/service_test.go

Lines changed: 112 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -133,91 +133,108 @@ func TestNewService(t *testing.T) {
133133
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
134134
testCertFilePath := createTestCertificate(t)
135135

136-
TLSSetting := serverTLS.TlsSettings{
137-
TLSPort: "10091",
138-
TLSSubjectName: "localhost",
139-
TLSCertificatePath: testCertFilePath,
140-
UseMTLS: true,
141-
MinTLSVersion: "TLS 1.2",
142-
MtlsClientCertSubjectName: "example.com",
143-
}
144-
145-
TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{
146-
TLSPort: "10092",
147-
TLSSubjectName: "localhost",
148-
TLSCertificatePath: testCertFilePath,
149-
UseMTLS: true,
150-
MinTLSVersion: "TLS 1.2",
151-
MtlsClientCertSubjectName: "random.com",
136+
cases := []struct {
137+
name string
138+
tlsSettings serverTLS.TlsSettings
139+
handshakeFailureExpected bool
140+
}{
141+
{
142+
name: "matching client SANs",
143+
tlsSettings: serverTLS.TlsSettings{
144+
TLSPort: "10091",
145+
TLSSubjectName: "localhost",
146+
TLSCertificatePath: testCertFilePath,
147+
UseMTLS: true,
148+
MinTLSVersion: "TLS 1.2",
149+
MtlsClientCertSubjectName: "example.com",
150+
},
151+
handshakeFailureExpected: false,
152+
},
153+
{
154+
name: "matching client cert CN",
155+
tlsSettings: serverTLS.TlsSettings{
156+
TLSPort: "10093",
157+
TLSSubjectName: "localhost",
158+
TLSCertificatePath: testCertFilePath,
159+
UseMTLS: true,
160+
MinTLSVersion: "TLS 1.2",
161+
MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate
162+
},
163+
handshakeFailureExpected: false,
164+
},
165+
{
166+
name: "failing to match client SANs and CN",
167+
tlsSettings: serverTLS.TlsSettings{
168+
TLSPort: "10092",
169+
TLSSubjectName: "localhost",
170+
TLSCertificatePath: testCertFilePath,
171+
UseMTLS: true,
172+
MinTLSVersion: "TLS 1.2",
173+
MtlsClientCertSubjectName: "random.com",
174+
},
175+
handshakeFailureExpected: true,
176+
},
152177
}
153178

154-
TLSSettingWithClientCertCN := serverTLS.TlsSettings{
155-
TLSPort: "10093",
156-
TLSSubjectName: "localhost",
157-
TLSCertificatePath: testCertFilePath,
158-
UseMTLS: true,
159-
MinTLSVersion: "TLS 1.2",
160-
MtlsClientCertSubjectName: "foo.com", // Common Name from test certificate
161-
}
179+
for _, tc := range cases {
180+
t.Run(tc.name, func(t *testing.T) {
181+
config.TLSSettings = tc.tlsSettings
162182

163-
runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) {
164-
config.TLSSettings = tlsSettings
165-
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
166-
require.NoError(t, err)
167-
require.IsType(t, &Service{}, svc)
183+
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
184+
require.NoError(t, err)
185+
require.IsType(t, &Service{}, svc)
168186

169-
svc.SetOption(acn.OptCnsURL, "")
170-
svc.SetOption(acn.OptCnsPort, "")
187+
svc.SetOption(acn.OptCnsURL, "")
188+
svc.SetOption(acn.OptCnsPort, "")
171189

172-
err = svc.Initialize(config)
173-
require.NoError(t, err)
190+
err = svc.Initialize(config)
191+
require.NoError(t, err)
174192

175-
err = svc.StartListener(config)
176-
require.NoError(t, err)
193+
err = svc.StartListener(config)
194+
require.NoError(t, err)
177195

178-
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
179-
require.NoError(t, err)
196+
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
197+
require.NoError(t, err)
180198

181-
client := &http.Client{
182-
Transport: &http.Transport{
183-
TLSClientConfig: mTLSConfig,
184-
},
185-
}
199+
client := &http.Client{
200+
Transport: &http.Transport{
201+
TLSClientConfig: mTLSConfig,
202+
},
203+
}
186204

187-
tlsURL := "https://localhost:" + tlsSettings.TLSPort
188-
// TLS listener
189-
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
190-
require.NoError(t, err)
191-
resp, err := client.Do(req)
192-
t.Cleanup(func() {
193-
if resp != nil && resp.Body != nil {
194-
resp.Body.Close()
205+
tlsURL := "https://localhost:" + tc.tlsSettings.TLSPort
206+
// TLS listener
207+
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody)
208+
require.NoError(t, err)
209+
resp, err := client.Do(req)
210+
t.Cleanup(func() {
211+
if resp != nil && resp.Body != nil {
212+
resp.Body.Close()
213+
}
214+
})
215+
if tc.handshakeFailureExpected {
216+
require.Error(t, err)
217+
require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS")
218+
} else {
219+
require.NoError(t, err)
195220
}
196-
})
197-
if handshakeFailureExpected {
198-
require.Error(t, err)
199-
require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS")
200221

201-
} else {
222+
// HTTP listener
223+
httpClient := &http.Client{}
224+
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
225+
require.NoError(t, err)
226+
resp, err = httpClient.Do(req)
227+
t.Cleanup(func() {
228+
if resp != nil && resp.Body != nil {
229+
resp.Body.Close()
230+
}
231+
})
202232
require.NoError(t, err)
203-
}
204233

205-
// HTTP listener
206-
httpClient := &http.Client{}
207-
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
208-
require.NoError(t, err)
209-
resp, err = httpClient.Do(req)
210-
t.Cleanup(func() {
211-
resp.Body.Close()
234+
// Cleanup
235+
svc.Uninitialize()
212236
})
213-
require.NoError(t, err)
214-
215-
// Cleanup
216-
svc.Uninitialize()
217237
}
218-
runMutualTLSTest(TLSSetting, false)
219-
runMutualTLSTest(TLSSettingWithClientCertCN, false)
220-
runMutualTLSTest(TLSSettingWithDisallowedClientSN, true)
221238
})
222239
}
223240

@@ -389,3 +406,28 @@ func TestTLSVersionNumber(t *testing.T) {
389406
require.NoError(t, err)
390407
})
391408
}
409+
410+
func TestMaskHalf(t *testing.T) {
411+
tests := []struct {
412+
name string
413+
in string
414+
want string
415+
}{
416+
{"empty", "", ""},
417+
{"one char string", "e", "*"},
418+
{"two chars string", "ex", "e*"},
419+
{"three chars string", "exa", "e**"},
420+
{"four chars string", "exam", "ex**"},
421+
{"five chars string", "examp", "ex***"},
422+
{"long string", "example.com", "examp******"},
423+
}
424+
425+
for _, tc := range tests {
426+
t.Run(tc.name, func(t *testing.T) {
427+
got := maskHalf(tc.in)
428+
if got != tc.want {
429+
t.Fatalf("maskHalf(%s) = %s, want %s", tc.in, got, tc.want)
430+
}
431+
})
432+
}
433+
}

0 commit comments

Comments
 (0)