|
3 | 3 |
|
4 | 4 | using System;
|
5 | 5 | using System.Diagnostics;
|
| 6 | +using System.Linq; |
6 | 7 | using System.Net;
|
7 | 8 | using System.Net.Http;
|
| 9 | +using System.Net.Security; |
8 | 10 | using System.Net.Sockets;
|
| 11 | +using System.Security.Claims; |
9 | 12 | using System.Threading;
|
10 | 13 | using System.Threading.Tasks;
|
11 | 14 | using Microsoft.Identity.Client;
|
@@ -1330,5 +1333,89 @@ await mi.AcquireTokenForManagedIdentity(Resource)
|
1330 | 1333 | Assert.AreEqual(httpManager.QueueSize, 0);
|
1331 | 1334 | }
|
1332 | 1335 | }
|
| 1336 | + |
| 1337 | + [TestMethod] |
| 1338 | + public void ValidateServerCertificate_OnlySetForServiceFabric() |
| 1339 | + { |
| 1340 | + using (new EnvVariableContext()) |
| 1341 | + using (var httpManager = new MockHttpManager()) |
| 1342 | + { |
| 1343 | + // Test all managed identity sources |
| 1344 | + foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource)) |
| 1345 | + .Cast<ManagedIdentitySource>() |
| 1346 | + .Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds)) |
| 1347 | + { |
| 1348 | + // Create a managed identity source for each type |
| 1349 | + AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager); |
| 1350 | + |
| 1351 | + // Check if ValidateServerCertificate is set based on the source type |
| 1352 | + bool shouldHaveCallback = sourceType == ManagedIdentitySource.ServiceFabric; |
| 1353 | + bool hasCallback = managedIdentity.GetValidationCallback() != null; |
| 1354 | + |
| 1355 | + Assert.AreEqual( |
| 1356 | + shouldHaveCallback, |
| 1357 | + hasCallback, |
| 1358 | + $"For source type {sourceType}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set"); |
| 1359 | + |
| 1360 | + // For ServiceFabric, verify it's set to the right method |
| 1361 | + if (sourceType == ManagedIdentitySource.ServiceFabric) |
| 1362 | + { |
| 1363 | + Assert.IsNotNull(managedIdentity.GetValidationCallback(), |
| 1364 | + "ServiceFabric should have ValidateServerCertificate set"); |
| 1365 | + |
| 1366 | + Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource), |
| 1367 | + "ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource"); |
| 1368 | + } |
| 1369 | + else |
| 1370 | + { |
| 1371 | + Assert.IsNull(managedIdentity.GetValidationCallback(), |
| 1372 | + $"Non-ServiceFabric source type {sourceType} should not have ValidateServerCertificate set"); |
| 1373 | + } |
| 1374 | + } |
| 1375 | + } |
| 1376 | + } |
| 1377 | + |
| 1378 | + private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySource sourceType, MockHttpManager httpManager) |
| 1379 | + { |
| 1380 | + string endpoint = "https://identity.endpoint.com"; |
| 1381 | + |
| 1382 | + // Setup environment based on the source type |
| 1383 | + SetEnvironmentVariables(sourceType, endpoint); |
| 1384 | + |
| 1385 | + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) |
| 1386 | + .WithHttpManager(httpManager); |
| 1387 | + |
| 1388 | + var managedIdentityApp = miBuilder.BuildConcrete(); |
| 1389 | + RequestContext requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); |
| 1390 | + |
| 1391 | + // Create the correct managed identity source based on the type |
| 1392 | + AbstractManagedIdentity managedIdentity = null; |
| 1393 | + |
| 1394 | + switch (sourceType) |
| 1395 | + { |
| 1396 | + case ManagedIdentitySource.ServiceFabric: |
| 1397 | + managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext); |
| 1398 | + break; |
| 1399 | + case ManagedIdentitySource.AppService: |
| 1400 | + managedIdentity = AppServiceManagedIdentitySource.Create(requestContext); |
| 1401 | + break; |
| 1402 | + case ManagedIdentitySource.AzureArc: |
| 1403 | + managedIdentity = AzureArcManagedIdentitySource.Create(requestContext); |
| 1404 | + break; |
| 1405 | + case ManagedIdentitySource.CloudShell: |
| 1406 | + managedIdentity = CloudShellManagedIdentitySource.Create(requestContext); |
| 1407 | + break; |
| 1408 | + case ManagedIdentitySource.Imds: |
| 1409 | + managedIdentity = new ImdsManagedIdentitySource(requestContext); |
| 1410 | + break; |
| 1411 | + case ManagedIdentitySource.MachineLearning: |
| 1412 | + managedIdentity = MachineLearningManagedIdentitySource.Create(requestContext); |
| 1413 | + break; |
| 1414 | + default: |
| 1415 | + throw new NotSupportedException($"Unsupported managed identity source type: {sourceType}"); |
| 1416 | + } |
| 1417 | + |
| 1418 | + return managedIdentity; |
| 1419 | + } |
1333 | 1420 | }
|
1334 | 1421 | }
|
0 commit comments