Skip to content

Commit 428838e

Browse files
Fix regression in ManagedIdentity flows (#5292)
* Fix regression * Address comments
1 parent ce722e2 commit 428838e

File tree

4 files changed

+114
-13
lines changed

4 files changed

+114
-13
lines changed

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
using System.Net;
1313
using Microsoft.Identity.Client.ApiConfig.Parameters;
1414
using System.Text;
15+
using System.Security.Cryptography.X509Certificates;
16+
using System.Net.Security;
17+
1518
#if SUPPORTS_SYSTEM_TEXT_JSON
1619
using System.Text.Json;
1720
#else
@@ -22,11 +25,13 @@ namespace Microsoft.Identity.Client.ManagedIdentity
2225
{
2326
internal abstract class AbstractManagedIdentity
2427
{
28+
private const string ManagedIdentityPrefix = "[Managed Identity] ";
29+
2530
protected readonly RequestContext _requestContext;
31+
2632
internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out.";
2733
internal readonly ManagedIdentitySource _sourceType;
28-
private const string ManagedIdentityPrefix = "[Managed Identity] ";
29-
34+
3035
protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType)
3136
{
3237
_requestContext = requestContext;
@@ -65,7 +70,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
6570
logger: _requestContext.Logger,
6671
doNotThrow: true,
6772
mtlsCertificate: null,
68-
validateServerCertificate: ValidateServerCertificate,
73+
validateServerCertificate: GetValidationCallback(),
6974
cancellationToken: cancellationToken,
7075
retryPolicy: request.RetryPolicy).ConfigureAwait(false);
7176
}
@@ -80,7 +85,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
8085
logger: _requestContext.Logger,
8186
doNotThrow: true,
8287
mtlsCertificate: null,
83-
validateServerCertificate: ValidateServerCertificate,
88+
validateServerCertificate: GetValidationCallback(),
8489
cancellationToken: cancellationToken,
8590
retryPolicy: request.RetryPolicy)
8691
.ConfigureAwait(false);
@@ -96,13 +101,14 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
96101
}
97102
}
98103

99-
// This method is used to validate the server certificate.
100-
// It is overridden in the Service Fabric managed identity source to validate the certificate thumbprint.
101-
// The default implementation always returns true.
102-
internal virtual bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
103-
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
104+
/// <summary>
105+
/// Method to be overridden in the derived classes to provide a custom validation callback for the server certificate.
106+
/// This validation is needed for service fabric managed identity endpoints.
107+
/// </summary>
108+
/// <returns>Callback to validate the server certificate.</returns>
109+
internal virtual Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> GetValidationCallback()
104110
{
105-
return true;
111+
return null;
106112
}
107113

108114
protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(

src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Globalization;
66
using System.Net.Http;
77
using System.Net.Security;
8+
using System.Security.Cryptography.X509Certificates;
89
using Microsoft.Identity.Client.Core;
910
using Microsoft.Identity.Client.Internal;
1011

@@ -15,6 +16,7 @@ internal class ServiceFabricManagedIdentitySource : AbstractManagedIdentity
1516
private const string ServiceFabricMsiApiVersion = "2019-07-01-preview";
1617
private readonly Uri _endpoint;
1718
private readonly string _identityHeaderValue;
19+
1820
internal static Lazy<HttpClient> _httpClientLazy;
1921

2022
public static AbstractManagedIdentity Create(RequestContext requestContext)
@@ -40,11 +42,17 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
4042
}
4143

4244
requestContext.Logger.Verbose(() => "[Managed Identity] Creating Service Fabric managed identity. Endpoint URI: " + identityEndpoint);
45+
4346
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
4447
}
4548

46-
internal override bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
47-
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
49+
internal override Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> GetValidationCallback()
50+
{
51+
return ValidateServerCertificateCallback;
52+
}
53+
54+
private bool ValidateServerCertificateCallback(HttpRequestMessage message, X509Certificate2 certificate,
55+
X509Chain chain, SslPolicyErrors sslPolicyErrors)
4856
{
4957
if (sslPolicyErrors == SslPolicyErrors.None)
5058
{

tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
using System;
55
using System.Diagnostics;
6+
using System.Linq;
67
using System.Net;
78
using System.Net.Http;
9+
using System.Net.Security;
810
using System.Net.Sockets;
11+
using System.Security.Claims;
912
using System.Threading;
1013
using System.Threading.Tasks;
1114
using Microsoft.Identity.Client;
@@ -1330,5 +1333,89 @@ await mi.AcquireTokenForManagedIdentity(Resource)
13301333
Assert.AreEqual(httpManager.QueueSize, 0);
13311334
}
13321335
}
1336+
1337+
[TestMethod]
1338+
public void ValidateServerCertificate_OnlySetForServiceFabric()
1339+
{
1340+
using (new EnvVariableContext())
1341+
using (var httpManager = new MockHttpManager())
1342+
{
1343+
// Test all managed identity sources
1344+
foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource))
1345+
.Cast<ManagedIdentitySource>()
1346+
.Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds))
1347+
{
1348+
// Create a managed identity source for each type
1349+
AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager);
1350+
1351+
// Check if ValidateServerCertificate is set based on the source type
1352+
bool shouldHaveCallback = sourceType == ManagedIdentitySource.ServiceFabric;
1353+
bool hasCallback = managedIdentity.GetValidationCallback() != null;
1354+
1355+
Assert.AreEqual(
1356+
shouldHaveCallback,
1357+
hasCallback,
1358+
$"For source type {sourceType}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set");
1359+
1360+
// For ServiceFabric, verify it's set to the right method
1361+
if (sourceType == ManagedIdentitySource.ServiceFabric)
1362+
{
1363+
Assert.IsNotNull(managedIdentity.GetValidationCallback(),
1364+
"ServiceFabric should have ValidateServerCertificate set");
1365+
1366+
Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource),
1367+
"ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource");
1368+
}
1369+
else
1370+
{
1371+
Assert.IsNull(managedIdentity.GetValidationCallback(),
1372+
$"Non-ServiceFabric source type {sourceType} should not have ValidateServerCertificate set");
1373+
}
1374+
}
1375+
}
1376+
}
1377+
1378+
private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySource sourceType, MockHttpManager httpManager)
1379+
{
1380+
string endpoint = "https://identity.endpoint.com";
1381+
1382+
// Setup environment based on the source type
1383+
SetEnvironmentVariables(sourceType, endpoint);
1384+
1385+
var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
1386+
.WithHttpManager(httpManager);
1387+
1388+
var managedIdentityApp = miBuilder.BuildConcrete();
1389+
RequestContext requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null);
1390+
1391+
// Create the correct managed identity source based on the type
1392+
AbstractManagedIdentity managedIdentity = null;
1393+
1394+
switch (sourceType)
1395+
{
1396+
case ManagedIdentitySource.ServiceFabric:
1397+
managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext);
1398+
break;
1399+
case ManagedIdentitySource.AppService:
1400+
managedIdentity = AppServiceManagedIdentitySource.Create(requestContext);
1401+
break;
1402+
case ManagedIdentitySource.AzureArc:
1403+
managedIdentity = AzureArcManagedIdentitySource.Create(requestContext);
1404+
break;
1405+
case ManagedIdentitySource.CloudShell:
1406+
managedIdentity = CloudShellManagedIdentitySource.Create(requestContext);
1407+
break;
1408+
case ManagedIdentitySource.Imds:
1409+
managedIdentity = new ImdsManagedIdentitySource(requestContext);
1410+
break;
1411+
case ManagedIdentitySource.MachineLearning:
1412+
managedIdentity = MachineLearningManagedIdentitySource.Create(requestContext);
1413+
break;
1414+
default:
1415+
throw new NotSupportedException($"Unsupported managed identity source type: {sourceType}");
1416+
}
1417+
1418+
return managedIdentity;
1419+
}
13331420
}
13341421
}

tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public void ValidateServerCertificateCallback_ServerCertificateValidationCallbac
8686
var sf = ServiceFabricManagedIdentitySource.Create(requestContext);
8787

8888
Assert.IsInstanceOfType(sf, typeof(ServiceFabricManagedIdentitySource));
89-
var callback = ((ServiceFabricManagedIdentitySource)sf).ValidateServerCertificate(null, certificate, chain, sslPolicyErrors);
89+
var callback = sf.GetValidationCallback();
9090
Assert.IsNotNull(callback);
9191
}
9292
}

0 commit comments

Comments
 (0)