Skip to content

Commit 4de6fc8

Browse files
authored
Fix issue where x5c header is not sent for OnBehalfOfCredential when SendCertificateChain option is set (Azure#27721)
* Fix SendCertificateChain option for OnBehalfOfCredential
1 parent 983dce7 commit 4de6fc8

File tree

4 files changed

+173
-26
lines changed

4 files changed

+173
-26
lines changed

sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ public virtual async ValueTask<AuthenticationResult> AcquireTokenOnBehalfOf(
170170
{
171171
IConfidentialClientApplication client = await GetClientAsync(async, cancellationToken).ConfigureAwait(false);
172172

173-
var builder = client.AcquireTokenOnBehalfOf(scopes, userAssertionValue);
173+
var builder = client
174+
.AcquireTokenOnBehalfOf(scopes, userAssertionValue)
175+
.WithSendX5C(_includeX5CClaimHeader);
174176

175177
if (!string.IsNullOrEmpty(tenantId))
176178
{

sdk/identity/Azure.Identity/tests/ClientCertificateCredentialTests.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Security.Cryptography.X509Certificates;
77
using System.Threading.Tasks;
88
using Azure.Core;
9+
using Azure.Core.Pipeline;
910
using Azure.Core.TestFramework;
1011
using Azure.Identity.Tests.Mock;
1112
using NUnit.Framework;
@@ -159,5 +160,32 @@ public async Task UsesTenantIdHint(
159160

160161
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
161162
}
163+
164+
[Test]
165+
public async Task SendCertificateChain([Values(true, false)] bool usePemFile, [Values(true)] bool sendCertChain)
166+
{
167+
TestSetup();
168+
var _transport = Createx5cValidatingTransport(sendCertChain);
169+
var _pipeline = new HttpPipeline(_transport, new[] {new BearerTokenAuthenticationPolicy(new MockCredential(), "scope")});
170+
var context = new TokenRequestContext(new[] { Scope }, tenantId: TenantId);
171+
expectedTenantId = TenantIdResolver.Resolve(TenantId, context);
172+
var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx");
173+
var certificatePathPem = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pem");
174+
var mockCert = new X509Certificate2(certificatePath);
175+
options = new ClientCertificateCredentialOptions();
176+
((ClientCertificateCredentialOptions)options).SendCertificateChain = sendCertChain;
177+
178+
ClientCertificateCredential credential = InstrumentClient(
179+
usePemFile
180+
? new ClientCertificateCredential(TenantId, ClientId, certificatePathPem, options,
181+
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)), null)
182+
: new ClientCertificateCredential(TenantId, ClientId, mockCert, options,
183+
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)), null)
184+
);
185+
186+
var token = await credential.GetTokenAsync(context);
187+
188+
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
189+
}
162190
}
163191
}

sdk/identity/Azure.Identity/tests/CredentialTestBase.cs

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
using System.Globalization;
66
using System.IO;
77
using System.Text;
8+
using System.Text.Json;
89
using System.Threading;
910
using System.Threading.Tasks;
11+
using Azure.Core;
1012
using Azure.Core.TestFramework;
1113
using Azure.Identity.Tests.Mock;
1214
using Microsoft.Identity.Client;
@@ -36,8 +38,12 @@ public class CredentialTestBase : ClientTestBase
3638
protected string expectedCode;
3739
protected DeviceCodeResult deviceCodeResult;
3840

41+
protected const string DiscoveryResponseBody =
42+
"{\"tenant_discovery_endpoint\": \"https://login.microsoftonline.com/c54fac88-3dd3-461f-a7c4-8a368e0340b3/v2.0/.well-known/openid-configuration\",\"api-version\": \"1.1\",\"metadata\":[{\"preferred_network\": \"login.microsoftonline.com\",\"preferred_cache\": \"login.windows.net\",\"aliases\":[\"login.microsoftonline.com\",\"login.windows.net\",\"login.microsoft.com\",\"sts.windows.net\"]},{\"preferred_network\": \"login.partner.microsoftonline.cn\",\"preferred_cache\": \"login.partner.microsoftonline.cn\",\"aliases\":[\"login.partner.microsoftonline.cn\",\"login.chinacloudapi.cn\"]},{\"preferred_network\": \"login.microsoftonline.de\",\"preferred_cache\": \"login.microsoftonline.de\",\"aliases\":[\"login.microsoftonline.de\"]},{\"preferred_network\": \"login.microsoftonline.us\",\"preferred_cache\": \"login.microsoftonline.us\",\"aliases\":[\"login.microsoftonline.us\",\"login.usgovcloudapi.net\"]},{\"preferred_network\": \"login-us.microsoftonline.com\",\"preferred_cache\": \"login-us.microsoftonline.com\",\"aliases\":[\"login-us.microsoftonline.com\"]}]}";
43+
3944
public CredentialTestBase(bool isAsync) : base(isAsync)
40-
{ }
45+
{
46+
}
4147

4248
public void TestSetup()
4349
{
@@ -57,7 +63,7 @@ public void TestSetup()
5763
TenantId,
5864
new MockAccount("username"),
5965
null,
60-
new[] { Scope },
66+
new[] {Scope},
6167
Guid.NewGuid(),
6268
null,
6369
"Bearer");
@@ -103,7 +109,7 @@ public void TestSetup()
103109
TenantId,
104110
new MockAccount("username"),
105111
null,
106-
new[] { Scope },
112+
new[] {Scope},
107113
Guid.NewGuid(),
108114
null,
109115
"Bearer");
@@ -150,6 +156,7 @@ protected async Task<string> ReadMockRequestContent(MockRequest request)
150156
{
151157
return null;
152158
}
159+
153160
using var memoryStream = new MemoryStream();
154161
request.Content.WriteTo(memoryStream, CancellationToken.None);
155162
memoryStream.Position = 0;
@@ -159,7 +166,8 @@ protected async Task<string> ReadMockRequestContent(MockRequest request)
159166
}
160167
}
161168

162-
protected MockResponse CreateMockMsalTokenResponse(int responseCode, string token, string tenantId, string userName)
169+
protected MockResponse CreateMockMsalTokenResponse(int responseCode, string token, string tenantId,
170+
string userName)
163171
{
164172
var response = new MockResponse(responseCode);
165173
var idToken = CreateMsalIdToken(Guid.NewGuid().ToString(), userName, tenantId);
@@ -190,7 +198,7 @@ public static string CreateMsalIdToken(string uniqueId, string displayableId, st
190198
return string.Format(CultureInfo.InvariantCulture, "someheader.{0}.somesignature", MsalEncode(id));
191199
}
192200

193-
private const char base64PadCharacter = '=';
201+
private const char base64PadCharacter = '=';
194202
#if NET45
195203
private const string doubleBase64PadCharacter = "==";
196204
#endif
@@ -204,11 +212,9 @@ public static string CreateMsalIdToken(string uniqueId, string displayableId, st
204212
/// </summary>
205213
internal static readonly char[] s_base64Table =
206214
{
207-
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z',
208-
'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z',
209-
'0','1','2','3','4','5','6','7','8','9',
210-
base64UrlCharacter62,
211-
base64UrlCharacter63
215+
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y',
216+
'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
217+
'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', base64UrlCharacter62, base64UrlCharacter63
212218
};
213219

214220
/// <summary>
@@ -302,7 +308,7 @@ private static string MsalEncode(byte[] inArray, int offset, int length)
302308
}
303309
break;
304310

305-
//default or case 0: no further operations are needed.
311+
//default or case 0: no further operations are needed.
306312
}
307313

308314
return new string(output, 0, j);
@@ -323,5 +329,71 @@ public static string MsalEncode(byte[] inArray)
323329

324330
return MsalEncode(inArray, 0, inArray.Length);
325331
}
332+
333+
protected bool RequestBodyHasUserAssertionWithHeader(Request req, string headerName)
334+
{
335+
req.Content.TryComputeLength(out var len);
336+
byte[] content = new byte[len];
337+
var stream = new MemoryStream((int)len);
338+
req.Content.WriteTo(stream, default);
339+
var body = Encoding.UTF8.GetString(stream.GetBuffer(), 0, (int)stream.Length);
340+
var parts = body.Split('&');
341+
foreach (var part in parts)
342+
{
343+
if (part.StartsWith("client_assertion="))
344+
{
345+
var assertion = part.AsSpan();
346+
int start = assertion.IndexOf('=') + 1;
347+
assertion = assertion.Slice(start);
348+
int end = assertion.IndexOf('.');
349+
var jwt = assertion.Slice(0, end);
350+
string convertedToken = jwt.ToString().Replace('_', '/').Replace('-', '+');
351+
switch (jwt.Length % 4)
352+
{
353+
case 2:
354+
convertedToken += "==";
355+
break;
356+
case 3:
357+
convertedToken += "=";
358+
break;
359+
}
360+
361+
Utf8JsonReader reader = new Utf8JsonReader(Convert.FromBase64String(convertedToken));
362+
while (reader.Read())
363+
{
364+
if (reader.TokenType == JsonTokenType.PropertyName)
365+
{
366+
var header = reader.GetString();
367+
if (header == headerName)
368+
{
369+
return true;
370+
}
371+
372+
reader.Read();
373+
}
374+
}
375+
}
376+
}
377+
378+
return false;
379+
}
380+
381+
protected MockTransport Createx5cValidatingTransport(bool sendCertChain) => new MockTransport((req) =>
382+
{
383+
// respond to tenant discovery
384+
if (req.Uri.Path.StartsWith("/common/discovery"))
385+
{
386+
return new MockResponse(200).SetContent(DiscoveryResponseBody);
387+
}
388+
389+
// respond to token request
390+
if (req.Uri.Path.EndsWith("/token"))
391+
{
392+
Assert.That(sendCertChain, Is.EqualTo(RequestBodyHasUserAssertionWithHeader(req, "x5c")));
393+
return new MockResponse(200).WithContent(
394+
$"{{\"token_type\": \"Bearer\",\"expires_in\": 9999,\"ext_expires_in\": 9999,\"access_token\": \"{expectedToken}\" }}");
395+
}
396+
return new MockResponse(200);
397+
});
326398
}
327399
}

sdk/identity/Azure.Identity/tests/OnBehalfOfCredentialTests.cs

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.IO;
56
using System.Security.Cryptography.X509Certificates;
7+
using System.Text;
8+
using System.Text.Json;
69
using System.Threading.Tasks;
710
using Azure.Core;
11+
using Azure.Core.Pipeline;
12+
using Azure.Core.TestFramework;
813
using Azure.Identity.Tests.Mock;
14+
using Microsoft.Diagnostics.Tracing.Parsers.AspNet;
915
using NUnit.Framework;
1016

1117
namespace Azure.Identity.Tests
@@ -25,27 +31,41 @@ public void CtorValidation()
2531
string userAssertion = Guid.NewGuid().ToString();
2632
string clientSecret = Guid.NewGuid().ToString();
2733

28-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, clientSecret, userAssertion, null));
29-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, clientSecret, userAssertion, null));
30-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
31-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, clientSecret, null, null));
34+
Assert.Throws<ArgumentNullException>(() =>
35+
new OnBehalfOfCredential(null, ClientId, clientSecret, userAssertion, null));
36+
Assert.Throws<ArgumentNullException>(() =>
37+
new OnBehalfOfCredential(TenantId, null, clientSecret, userAssertion, null));
38+
Assert.Throws<ArgumentNullException>(() =>
39+
new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
40+
Assert.Throws<ArgumentNullException>(() =>
41+
new OnBehalfOfCredential(TenantId, ClientId, clientSecret, null, null));
3242
cred = new OnBehalfOfCredential(TenantId, ClientId, clientSecret, userAssertion, null);
3343
// Assert
3444
Assert.AreEqual(clientSecret, cred._client._clientSecret);
3545

36-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion));
37-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion));
38-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
39-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null));
46+
Assert.Throws<ArgumentNullException>(() =>
47+
new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion));
48+
Assert.Throws<ArgumentNullException>(() =>
49+
new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion));
50+
Assert.Throws<ArgumentNullException>(() =>
51+
new OnBehalfOfCredential(TenantId, ClientId, default(string), userAssertion));
52+
Assert.Throws<ArgumentNullException>(() =>
53+
new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null));
4054
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion);
4155
// Assert
4256
Assert.NotNull(cred._client._certificateProvider);
4357

44-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions()));
45-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions()));
46-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, default(X509Certificate2), userAssertion, new OnBehalfOfCredentialOptions()));
47-
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null, new OnBehalfOfCredentialOptions()));
48-
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion, new OnBehalfOfCredentialOptions());
58+
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(null, ClientId, _mockCertificate,
59+
userAssertion, new OnBehalfOfCredentialOptions()));
60+
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, null, _mockCertificate,
61+
userAssertion, new OnBehalfOfCredentialOptions()));
62+
Assert.Throws<ArgumentNullException>(() => new OnBehalfOfCredential(TenantId, ClientId,
63+
default(X509Certificate2), userAssertion, new OnBehalfOfCredentialOptions()));
64+
Assert.Throws<ArgumentNullException>(() =>
65+
new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, null,
66+
new OnBehalfOfCredentialOptions()));
67+
cred = new OnBehalfOfCredential(TenantId, ClientId, _mockCertificate, userAssertion,
68+
new OnBehalfOfCredentialOptions());
4969
// Assert
5070
Assert.NotNull(cred._client._certificateProvider);
5171
}
@@ -58,7 +78,7 @@ public async Task UsesTenantIdHint(
5878
{
5979
TestSetup();
6080
options = new OnBehalfOfCredentialOptions();
61-
var context = new TokenRequestContext(new[] { Scope }, tenantId: tenantId);
81+
var context = new TokenRequestContext(new[] {Scope}, tenantId: tenantId);
6282
expectedTenantId = TenantIdResolver.Resolve(explicitTenantId, context);
6383
OnBehalfOfCredential client = InstrumentClient(
6484
new OnBehalfOfCredential(
@@ -73,5 +93,30 @@ public async Task UsesTenantIdHint(
7393
var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default);
7494
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
7595
}
96+
97+
[Test]
98+
public async Task SendCertificateChain([Values(true, false)] bool sendCertChain)
99+
{
100+
TestSetup();
101+
var _transport = Createx5cValidatingTransport(sendCertChain);
102+
var _pipeline = new HttpPipeline(_transport, new[] {new BearerTokenAuthenticationPolicy(new MockCredential(), "scope")});
103+
var certificatePath = Path.Combine(TestContext.CurrentContext.TestDirectory, "Data", "cert.pfx");
104+
var mockCert = new X509Certificate2(certificatePath);
105+
106+
options = new OnBehalfOfCredentialOptions();
107+
((OnBehalfOfCredentialOptions)options).SendCertificateChain = sendCertChain;
108+
OnBehalfOfCredential client = InstrumentClient(
109+
new OnBehalfOfCredential(
110+
TenantId,
111+
ClientId,
112+
mockCert,
113+
expectedUserAssertion,
114+
options as OnBehalfOfCredentialOptions,
115+
new CredentialPipeline(new Uri("https://localhost"), _pipeline, new ClientDiagnostics(options)),
116+
null));
117+
118+
var token = await client.GetTokenAsync(new TokenRequestContext(MockScopes.Default), default);
119+
Assert.AreEqual(token.Token, expectedToken, "Should be the expected token value");
120+
}
76121
}
77122
}

0 commit comments

Comments
 (0)