Skip to content

Commit 9bc2b98

Browse files
authored
Support certificate chains in SNI auth (Azure#18832)
1 parent 1e10271 commit 9bc2b98

File tree

5 files changed

+40
-65
lines changed

5 files changed

+40
-65
lines changed

sdk/azidentity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
### Breaking Changes
88

99
### Bugs Fixed
10+
* `ClientCertificateCredential` sends only the leaf cert for SNI authentication
1011

1112
### Other Changes
1213

sdk/azidentity/azidentity_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package azidentity
88

99
import (
1010
"context"
11+
"crypto/x509"
1112
"errors"
1213
"fmt"
1314
"io"
@@ -122,7 +123,7 @@ const (
122123
testHost = "https://localhost"
123124
)
124125

125-
func validateJWTRequestContainsHeader(t *testing.T, headerName string) mock.ResponsePredicate {
126+
func validateX5C(t *testing.T, certs []*x509.Certificate) mock.ResponsePredicate {
126127
return func(req *http.Request) bool {
127128
body, err := io.ReadAll(req.Body)
128129
if err != nil {
@@ -135,8 +136,10 @@ func validateJWTRequestContainsHeader(t *testing.T, headerName string) mock.Resp
135136
if token == nil {
136137
t.Fatalf("Failed to parse the JWT token: %s.", assertion[1])
137138
}
138-
if _, ok := token.Header[headerName]; !ok {
139-
t.Fatalf("JWT did not contain the %s header", headerName)
139+
if v, ok := token.Header["x5c"].([]any); !ok {
140+
t.Fatal("missing x5c header")
141+
} else if actual := len(v); actual != len(certs) {
142+
t.Fatalf("expected %d certs, got %d", len(certs), actual)
140143
}
141144
return true
142145
}

sdk/azidentity/client_certificate_credential.go

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ package azidentity
99
import (
1010
"context"
1111
"crypto"
12-
"crypto/rsa"
13-
"crypto/sha1"
1412
"crypto/x509"
15-
"encoding/base64"
1613
"encoding/pem"
1714
"errors"
1815
"os"
@@ -46,10 +43,6 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
4643
if len(certs) == 0 {
4744
return nil, errors.New("at least one certificate is required")
4845
}
49-
pk, ok := key.(*rsa.PrivateKey)
50-
if !ok {
51-
return nil, errors.New("'key' must be an *rsa.PrivateKey")
52-
}
5346
if !validTenantID(tenantID) {
5447
return nil, errors.New(tenantIDValidationErr)
5548
}
@@ -60,11 +53,7 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
6053
if err != nil {
6154
return nil, err
6255
}
63-
cert, err := newCertContents(certs, pk, options.SendCertificateChain)
64-
if err != nil {
65-
return nil, err
66-
}
67-
cred := confidential.NewCredFromCert(cert.c, key) // TODO: NewCredFromCert should take a slice
56+
cred, err := confidential.NewCredFromCertChain(certs, key)
6857
if err != nil {
6958
return nil, err
7059
}
@@ -156,36 +145,6 @@ func ParseCertificates(certData []byte, password []byte) ([]*x509.Certificate, c
156145
return certs, pk, nil
157146
}
158147

159-
type certContents struct {
160-
c *x509.Certificate // the signing cert
161-
fp []byte // the signing cert's fingerprint, a SHA-1 digest
162-
pk *rsa.PrivateKey // the signing key
163-
x5c []string // concatenation of every provided cert, base64 encoded
164-
}
165-
166-
func newCertContents(certs []*x509.Certificate, key *rsa.PrivateKey, sendCertificateChain bool) (*certContents, error) {
167-
cc := certContents{pk: key}
168-
// need the the signing cert's fingerprint: identify that cert by matching its public key to the private key
169-
for _, cert := range certs {
170-
certKey, ok := cert.PublicKey.(*rsa.PublicKey)
171-
if ok && key.E == certKey.E && key.N.Cmp(certKey.N) == 0 {
172-
fp := sha1.Sum(cert.Raw)
173-
cc.fp = fp[:]
174-
cc.c = cert
175-
if sendCertificateChain {
176-
// signing cert must be first in x5c
177-
cc.x5c = append([]string{base64.StdEncoding.EncodeToString(cert.Raw)}, cc.x5c...)
178-
}
179-
} else if sendCertificateChain {
180-
cc.x5c = append(cc.x5c, base64.StdEncoding.EncodeToString(cert.Raw))
181-
}
182-
}
183-
if len(cc.fp) == 0 || cc.c == nil {
184-
return nil, errors.New("found no certificate matching 'key'")
185-
}
186-
return &cc, nil
187-
}
188-
189148
func loadPEMCert(certData []byte) ([]*pem.Block, error) {
190149
blocks := []*pem.Block{}
191150
for {

sdk/azidentity/client_certificate_credential_test.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ func newCertTest(name, certPath string, password string) certTest {
3939
var allCertTests = []certTest{
4040
newCertTest("pem", "testdata/certificate.pem", ""),
4141
newCertTest("pemB", "testdata/certificate_formatB.pem", ""),
42+
newCertTest("pemChain", "testdata/certificate-with-chain.pem", ""),
4243
newCertTest("pkcs12", "testdata/certificate.pfx", ""),
4344
newCertTest("pkcs12Encrypted", "testdata/certificate_encrypted_key.pfx", "password"),
4445
}
@@ -107,26 +108,29 @@ func TestClientCertificateCredential_GetTokenSuccess_withCertificateChain(t *tes
107108
}
108109
}
109110

110-
func TestClientCertificateCredential_GetTokenSuccess_withCertificateChain_mock(t *testing.T) {
111-
test := allCertTests[0]
112-
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
113-
defer close()
114-
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
115-
srv.AppendResponse(mock.WithBody([]byte(tenantDiscoveryResponse)))
116-
srv.AppendResponse(mock.WithPredicate(validateJWTRequestContainsHeader(t, "x5c")), mock.WithBody([]byte(accessTokenRespSuccess)))
117-
srv.AppendResponse()
111+
func TestClientCertificateCredential_SendCertificateChain(t *testing.T) {
112+
for _, test := range allCertTests {
113+
t.Run(test.name, func(t *testing.T) {
114+
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
115+
defer close()
116+
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
117+
srv.AppendResponse(mock.WithBody([]byte(tenantDiscoveryResponse)))
118+
srv.AppendResponse(mock.WithPredicate(validateX5C(t, test.certs)), mock.WithBody([]byte(accessTokenRespSuccess)))
119+
srv.AppendResponse()
118120

119-
options := ClientCertificateCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}, SendCertificateChain: true}
120-
cred, err := NewClientCertificateCredential(fakeTenantID, fakeClientID, test.certs, test.key, &options)
121-
if err != nil {
122-
t.Fatal(err)
123-
}
124-
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
125-
if err != nil {
126-
t.Fatal(err)
127-
}
128-
if tk.Token != tokenValue {
129-
t.Fatalf("unexpected token: %s", tk.Token)
121+
options := ClientCertificateCredentialOptions{ClientOptions: azcore.ClientOptions{Transport: srv}, SendCertificateChain: true}
122+
cred, err := NewClientCertificateCredential(fakeTenantID, fakeClientID, test.certs, test.key, &options)
123+
if err != nil {
124+
t.Fatal(err)
125+
}
126+
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
127+
if err != nil {
128+
t.Fatal(err)
129+
}
130+
if tk.Token != tokenValue {
131+
t.Fatalf("unexpected token: %s", tk.Token)
132+
}
133+
})
130134
}
131135
}
132136

sdk/azidentity/environment_credential_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,20 @@ func TestEnvironmentCredential_UsernamePasswordSet(t *testing.T) {
195195
}
196196

197197
func TestEnvironmentCredential_SendCertificateChain(t *testing.T) {
198+
certData, err := os.ReadFile(liveSP.pfxPath)
199+
if err != nil {
200+
t.Fatal(err)
201+
}
202+
certs, _, err := ParseCertificates(certData, nil)
203+
if err != nil {
204+
t.Fatal(err)
205+
}
198206
resetEnvironmentVarsForTest()
199207
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
200208
defer close()
201209
srv.AppendResponse(mock.WithBody(instanceDiscoveryResponse))
202210
srv.AppendResponse(mock.WithBody([]byte(tenantDiscoveryResponse)))
203-
srv.AppendResponse(mock.WithPredicate(validateJWTRequestContainsHeader(t, "x5c")), mock.WithBody([]byte(accessTokenRespSuccess)))
211+
srv.AppendResponse(mock.WithPredicate(validateX5C(t, certs)), mock.WithBody([]byte(accessTokenRespSuccess)))
204212
srv.AppendResponse()
205213

206214
vars := map[string]string{

0 commit comments

Comments
 (0)