Skip to content

Commit ba1d357

Browse files
Update the SF flow to get a http client from the factory (#5221)
* Update the SF flow to get a http client from the factory * Update the method to pass callback instead of handler * Fix build failures * Address comments * Address comments * Fix tests
1 parent d88df71 commit ba1d357

File tree

24 files changed

+383
-191
lines changed

24 files changed

+383
-191
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Net.Http;
6+
using System.Net.Security;
7+
using System.Security.Cryptography.X509Certificates;
8+
9+
namespace Microsoft.Identity.Client
10+
{
11+
/// <summary>
12+
/// Factory responsible for creating HttpClient with a custom server certificate validation callback.
13+
/// This is useful for the Service Fabric scenario where the server certificate validation is required using the server cert.
14+
/// See https://learn.microsoft.com/dotnet/api/system.net.http.httpclient?view=net-7.0#instancing for more details.
15+
/// </summary>
16+
/// <remarks>
17+
/// Implementations must be thread safe.
18+
/// Do not create a new HttpClient for each call to <see cref="GetHttpClient"/> - this leads to socket exhaustion.
19+
/// If your app uses Integrated Windows Authentication, ensure <see cref="HttpClientHandler.UseDefaultCredentials"/> is set to true.
20+
/// </remarks>
21+
public interface IMsalSFHttpClientFactory : IMsalHttpClientFactory
22+
{
23+
24+
/// <summary>
25+
/// Method returning an HTTP client that will be used to validate the server certificate through the provided callback.
26+
/// This method is useful when custom certificate validation logic is required,
27+
/// for the managed identity flow running on a service fabric cluster.
28+
/// </summary>
29+
/// <param name="validateServerCert">Callback to validate the server certificate for the Service Fabric.</param>
30+
/// <returns>An HTTP client configured with the provided server certificate validation callback.</returns>
31+
HttpClient GetHttpClient(Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert);
32+
}
33+
}

src/client/Microsoft.Identity.Client/Http/HttpManager.cs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.IO;
99
using System.Net;
1010
using System.Net.Http;
11+
using System.Net.Security;
1112
using System.Security.Cryptography.X509Certificates;
1213
using System.Threading;
1314
using System.Threading.Tasks;
@@ -51,8 +52,8 @@ public async Task<HttpResponse> SendRequestAsync(
5152
ILoggerAdapter logger,
5253
bool doNotThrow,
5354
X509Certificate2 bindingCertificate,
54-
HttpClient customHttpClient,
55-
CancellationToken cancellationToken,
55+
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert,
56+
CancellationToken cancellationToken,
5657
int retryCount = 0)
5758
{
5859
Exception timeoutException = null;
@@ -76,8 +77,7 @@ public async Task<HttpResponse> SendRequestAsync(
7677
clonedBody,
7778
method,
7879
bindingCertificate,
79-
customHttpClient,
80-
logger,
80+
validateServerCert, logger,
8181
cancellationToken).ConfigureAwait(false);
8282
}
8383

@@ -113,9 +113,8 @@ public async Task<HttpResponse> SendRequestAsync(
113113
logger,
114114
doNotThrow,
115115
bindingCertificate,
116-
customHttpClient,
117-
cancellationToken: cancellationToken,
118-
retryCount) // Pass the updated retry count
116+
validateServerCert, cancellationToken: cancellationToken,
117+
retryCount: retryCount) // Pass the updated retry count
119118
.ConfigureAwait(false);
120119
}
121120

@@ -146,15 +145,32 @@ public async Task<HttpResponse> SendRequestAsync(
146145
return response;
147146
}
148147

149-
private HttpClient GetHttpClient(X509Certificate2 x509Certificate2, HttpClient customHttpClient) {
150-
if (x509Certificate2 != null && customHttpClient != null)
148+
private HttpClient GetHttpClient(X509Certificate2 x509Certificate2, Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert)
149+
{
150+
if (x509Certificate2 != null && validateServerCert != null)
151151
{
152152
throw new NotImplementedException("Mtls certificate cannot be used with service fabric. A custom http client is used for service fabric managed identity to validate the server certificate.");
153153
}
154154

155-
if (customHttpClient != null)
155+
if (validateServerCert != null)
156156
{
157-
return customHttpClient;
157+
// If the factory is an IMsalSFHttpClientFactory, use it to get an HttpClient with the custom handler
158+
// that validates the server certificate.
159+
if (_httpClientFactory is IMsalSFHttpClientFactory msalSFHttpClientFactory)
160+
{
161+
return msalSFHttpClientFactory.GetHttpClient(validateServerCert);
162+
}
163+
164+
#if NET471_OR_GREATER || NETSTANDARD || NET
165+
// If the factory is not an IMsalSFHttpClientFactory, use it to get a default HttpClient
166+
return new HttpClient(new HttpClientHandler()
167+
{
168+
169+
ServerCertificateCustomValidationCallback = validateServerCert
170+
});
171+
#else
172+
return _httpClientFactory.GetHttpClient();
173+
#endif
158174
}
159175

160176
if (_httpClientFactory is IMsalMtlsHttpClientFactory msalMtlsHttpClientFactory)
@@ -188,7 +204,7 @@ private async Task<HttpResponse> ExecuteAsync(
188204
HttpContent body,
189205
HttpMethod method,
190206
X509Certificate2 bindingCertificate,
191-
HttpClient customHttpClient,
207+
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert,
192208
ILoggerAdapter logger,
193209
CancellationToken cancellationToken = default)
194210
{
@@ -203,7 +219,7 @@ private async Task<HttpResponse> ExecuteAsync(
203219

204220
Stopwatch sw = Stopwatch.StartNew();
205221

206-
HttpClient client = GetHttpClient(bindingCertificate, customHttpClient);
222+
HttpClient client = GetHttpClient(bindingCertificate, validateServerCert);
207223

208224
using (HttpResponseMessage responseMessage =
209225
await client.SendAsync(requestMessage, cancellationToken).ConfigureAwait(false))

src/client/Microsoft.Identity.Client/Http/IHttpManager.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
using System;
55
using System.Collections.Generic;
66
using System.Net.Http;
7+
using System.Net.Security;
78
using System.Security.Cryptography.X509Certificates;
89
using System.Threading;
910
using System.Threading.Tasks;
1011
using Microsoft.Identity.Client.Core;
11-
using Microsoft.Identity.Client.Internal;
1212

1313
namespace Microsoft.Identity.Client.Http
1414
{
@@ -26,8 +26,7 @@ internal interface IHttpManager
2626
/// <param name="logger">Logger from the request context.</param>
2727
/// <param name="doNotThrow">Flag to decide if MsalServiceException is thrown or the response is returned in case of 5xx errors.</param>
2828
/// <param name="mtlsCertificate">Certificate used for MTLS authentication.</param>
29-
/// <param name="customHttpClient">Custom http client which bypasses the HttpClientFactory.
30-
/// This is needed for service fabric managed identity where a cert validation callback is added to the handler.</param>
29+
/// <param name="validateServerCertificate">Callback to validate the server cert for service fabric managed identity flow.</param>
3130
/// <param name="cancellationToken"></param>
3231
/// <param name="retryCount">Number of retries to be attempted in case of retriable status codes.</param>
3332
/// <returns></returns>
@@ -39,7 +38,7 @@ Task<HttpResponse> SendRequestAsync(
3938
ILoggerAdapter logger,
4039
bool doNotThrow,
4140
X509Certificate2 mtlsCertificate,
42-
HttpClient customHttpClient,
41+
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCertificate,
4342
CancellationToken cancellationToken,
4443
int retryCount = 0);
4544
}

src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,14 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation
199199
Uri imdsUri = BuildImdsUri(DefaultApiVersion);
200200

201201
HttpResponse response = await _httpManager.SendRequestAsync(
202-
imdsUri,
203-
headers,
204-
body: null,
205-
HttpMethod.Get,
206-
logger: logger,
207-
doNotThrow: false,
208-
mtlsCertificate: null,
209-
customHttpClient: null,
210-
GetCancellationToken(requestCancellationToken))
202+
imdsUri,
203+
headers,
204+
body: null,
205+
method: HttpMethod.Get,
206+
logger: logger,
207+
doNotThrow: false,
208+
mtlsCertificate: null,
209+
validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken))
211210
.ConfigureAwait(false);
212211

213212
// A bad request occurs when the version in the IMDS call is no longer supported.
@@ -219,12 +218,11 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation
219218
imdsUri,
220219
headers,
221220
body: null,
222-
HttpMethod.Get,
221+
method: HttpMethod.Get,
223222
logger: logger,
224223
doNotThrow: false,
225224
mtlsCertificate: null,
226-
customHttpClient: null,
227-
GetCancellationToken(requestCancellationToken))
225+
validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken))
228226
.ConfigureAwait(false); // Call again with updated version
229227
}
230228

@@ -318,16 +316,16 @@ private async Task<string> GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dict
318316
Uri imdsErrorUri = new(ImdsEndpoint);
319317

320318
HttpResponse response = await _httpManager.SendRequestAsync(
321-
imdsErrorUri,
322-
headers,
323-
body: null,
324-
HttpMethod.Get,
325-
logger: logger,
326-
doNotThrow: false,
327-
mtlsCertificate: null,
328-
customHttpClient: null,
329-
GetCancellationToken(userCancellationToken))
330-
.ConfigureAwait(false);
319+
imdsErrorUri,
320+
headers,
321+
body: null,
322+
method: HttpMethod.Get,
323+
logger: logger,
324+
doNotThrow: false,
325+
mtlsCertificate: null,
326+
validateServerCertificate: null,
327+
cancellationToken: GetCancellationToken(userCancellationToken))
328+
.ConfigureAwait(false);
331329

332330
// When IMDS endpoint is called without the api version query param, bad request response comes back with latest version.
333331
if (response.StatusCode == HttpStatusCode.BadRequest)

src/client/Microsoft.Identity.Client/Instance/Validation/AdfsAuthorityValidator.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the MIT License.
33

44
using System;
5-
using System.Globalization;
65
using System.Linq;
76
using System.Net;
87
using System.Threading.Tasks;
@@ -33,12 +32,11 @@ public async Task ValidateAuthorityAsync(
3332
new Uri(webFingerUrl),
3433
null,
3534
body: null,
36-
System.Net.Http.HttpMethod.Get,
35+
method: System.Net.Http.HttpMethod.Get,
3736
logger: _requestContext.Logger,
3837
doNotThrow: false,
3938
mtlsCertificate: null,
40-
customHttpClient: null,
41-
_requestContext.UserCancellationToken)
39+
validateServerCertificate: null, cancellationToken: _requestContext.UserCancellationToken)
4240
.ConfigureAwait(false);
4341

4442
if (httpResponse.StatusCode != HttpStatusCode.OK)

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
6161
request.ComputeUri(),
6262
request.Headers,
6363
body: null,
64-
HttpMethod.Get,
64+
method: HttpMethod.Get,
6565
logger: _requestContext.Logger,
6666
doNotThrow: true,
6767
mtlsCertificate: null,
68-
GetHttpClientWithSslValidation(_requestContext),
69-
cancellationToken).ConfigureAwait(false);
68+
validateServerCertificate: null, cancellationToken: cancellationToken).ConfigureAwait(false);
7069
}
7170
else
7271
{
@@ -75,12 +74,11 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
7574
request.ComputeUri(),
7675
request.Headers,
7776
body: new FormUrlEncodedContent(request.BodyParameters),
78-
HttpMethod.Post,
77+
method: HttpMethod.Post,
7978
logger: _requestContext.Logger,
8079
doNotThrow: true,
8180
mtlsCertificate: null,
82-
GetHttpClientWithSslValidation(_requestContext),
83-
cancellationToken)
81+
validateServerCertificate: null, cancellationToken: cancellationToken)
8482
.ConfigureAwait(false);
8583

8684
}
@@ -94,10 +92,13 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
9492
}
9593
}
9694

97-
// This method is internal for testing purposes.
98-
internal virtual HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
95+
// This method is used to validate the server certificate.
96+
// It is overridden in the Service Fabric managed identity source to validate the certificate thumbprint.
97+
// The default implementation always returns true.
98+
internal virtual bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
99+
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
99100
{
100-
return null;
101+
return true;
101102
}
102103

103104
protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(

src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
using System;
55
using System.Globalization;
66
using System.IO;
7-
using System.Linq;
87
using System.Threading;
98
using System.Threading.Tasks;
109
using Microsoft.Identity.Client.ApiConfig.Parameters;
1110
using Microsoft.Identity.Client.Core;
12-
using Microsoft.Identity.Client.Extensibility;
1311
using Microsoft.Identity.Client.Http;
1412
using Microsoft.Identity.Client.Internal;
1513
using Microsoft.Identity.Client.PlatformsCommon.Shared;
16-
using Microsoft.Identity.Client.Utils;
1714

1815
namespace Microsoft.Identity.Client.ManagedIdentity
1916
{
@@ -127,16 +124,16 @@ protected override async Task<ManagedIdentityResponse> HandleResponseAsync(
127124
request.Headers.Add("Authorization", authHeaderValue);
128125

129126
response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync(
130-
request.ComputeUri(),
131-
request.Headers,
132-
body: null,
133-
System.Net.Http.HttpMethod.Get,
134-
logger: _requestContext.Logger,
135-
doNotThrow: false,
136-
mtlsCertificate: null,
137-
customHttpClient: null,
138-
cancellationToken)
139-
.ConfigureAwait(false);
127+
request.ComputeUri(),
128+
request.Headers,
129+
body: null,
130+
method: System.Net.Http.HttpMethod.Get,
131+
logger: _requestContext.Logger,
132+
doNotThrow: false,
133+
mtlsCertificate: null,
134+
validateServerCertificate: null,
135+
cancellationToken: cancellationToken)
136+
.ConfigureAwait(false);
140137

141138
return await base.HandleResponseAsync(parameters, response, cancellationToken).ConfigureAwait(false);
142139
}

src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Globalization;
66
using System.Net.Http;
7+
using System.Net.Security;
78
using Microsoft.Identity.Client.Core;
89
using Microsoft.Identity.Client.Internal;
910

@@ -42,42 +43,17 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
4243
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
4344
}
4445

45-
internal override HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
46+
internal override bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
47+
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
4648
{
47-
if (_httpClientLazy == null)
49+
if (sslPolicyErrors == SslPolicyErrors.None)
4850
{
49-
_httpClientLazy = new Lazy<HttpClient>(() =>
50-
{
51-
HttpClientHandler handler = CreateHandlerWithSslValidation(requestContext.Logger);
52-
return new HttpClient(handler);
53-
});
51+
return true;
5452
}
5553

56-
return _httpClientLazy.Value;
54+
return string.Equals(certificate.GetCertHashString(), EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
5755
}
5856

59-
internal HttpClientHandler CreateHandlerWithSslValidation(ILoggerAdapter logger)
60-
{
61-
#if NET471_OR_GREATER || NETSTANDARD || NET
62-
logger.Info(() => "[Managed Identity] Setting up server certificate validation callback.");
63-
return new HttpClientHandler
64-
{
65-
ServerCertificateCustomValidationCallback = (message, certificate, chain, sslPolicyErrors) =>
66-
{
67-
if (sslPolicyErrors != System.Net.Security.SslPolicyErrors.None)
68-
{
69-
return 0 == string.Compare(certificate.Thumbprint, EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
70-
}
71-
return true;
72-
}
73-
};
74-
#else
75-
logger.Warning("[Managed Identity] Server certificate validation callback is not supported on .NET Framework.");
76-
return new HttpClientHandler();
77-
#endif
78-
}
79-
80-
8157
private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
8258
base(requestContext, ManagedIdentitySource.ServiceFabric)
8359
{

0 commit comments

Comments
 (0)