From 5498968b64d291e6fc73645e17e949b84a9cf41b Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 12:12:25 -0400 Subject: [PATCH 01/41] Initial commit. 2 TODOs --- .../ManagedIdentity/CsrRequest.cs | 41 ++++++++++ .../ManagedIdentity/CsrRequestResponse.cs | 53 +++++++++++++ .../ImdsV2ManagedIdentitySource.cs | 74 ++++++++++++++++++- .../ManagedIdentityTests/ImdsV2Tests.cs | 2 + 4 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs new file mode 100644 index 0000000000..eda86d9325 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class CsrRequest + { + public string Pem { get; } + + public CsrRequest(string pem) + { + Pem = pem ?? throw new ArgumentNullException(nameof(pem)); + } + + /// + /// Generates a CSR for the given client, tenant, and CUID info. + /// + /// Managed Identity client_id. + /// AAD tenant_id. + /// CuidInfo object containing VMID and VMSSID. + /// CsrRequest containing the PEM CSR. + public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) + { + if (string.IsNullOrWhiteSpace(clientId)) + throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); + if (string.IsNullOrWhiteSpace(tenantId)) + throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); + if (cuid == null) + throw new ArgumentNullException(nameof(cuid)); + if (string.IsNullOrWhiteSpace(cuid.Vmid)) + throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); + if (string.IsNullOrWhiteSpace(cuid.Vmssid)) + throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); + + // TODO: Implement the actual CSR generation logic. + return new CsrRequest("pem"); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs new file mode 100644 index 0000000000..10274e48ba --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else +using Microsoft.Identity.Client.Utils; +using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Represents the response for a Managed Identity CSR request. + /// + internal class CsrRequestResponse + { + [JsonProperty("client_id")] + public string ClientId { get; } + + [JsonProperty("tenant_id")] + public string TenantId { get; } + + [JsonProperty("client_credential")] + public string ClientCredential { get; } + + [JsonProperty("regional_token_url")] + public string RegionalTokenUrl { get; } + + [JsonProperty("expires_in")] + public int ExpiresIn { get; } + + [JsonProperty("refresh_in")] + public int RefreshIn { get; } + + public CsrRequestResponse() { } + + public static bool ValidateCsrRequestResponse(CsrRequestResponse csrRequestResponse) + { + if (string.IsNullOrEmpty(csrRequestResponse.ClientId) || + string.IsNullOrEmpty(csrRequestResponse.TenantId) || + string.IsNullOrEmpty(csrRequestResponse.ClientCredential) || + string.IsNullOrEmpty(csrRequestResponse.RegionalTokenUrl) || + csrRequestResponse.ExpiresIn <= 0 || + csrRequestResponse.RefreshIn <= 0) + { + return false; + } + + return true; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 9db03cc298..6fef52849b 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Net; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; @@ -16,6 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; + private const string CsrRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -29,7 +31,7 @@ public static async Task GetCsrMetadataAsync( requestContext.Logger); if (userAssignedIdQueryParam != null) { - queryParams += $"{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; } var headers = new Dictionary @@ -41,7 +43,6 @@ public static async Task GetCsrMetadataAsync( IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.CsrMetadataProbe); - // CSR metadata GET request HttpResponse response = null; try @@ -50,7 +51,7 @@ public static async Task GetCsrMetadataAsync( ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrMetadataPath, queryParams), headers, body: null, - method: System.Net.Http.HttpMethod.Get, + method: HttpMethod.Get, logger: requestContext.Logger, doNotThrow: false, mtlsCertificate: null, @@ -194,8 +195,75 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } + private async Task ExecuteCsrRequestAsync( + RequestContext requestContext, + string queryParams, + string pem) + { + var headers = new Dictionary + { + { "Metadata", "true" }, + { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + }; + + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); + + HttpResponse response = null; + + try + { + response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrRequestPath, queryParams), + headers, + body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), + method: HttpMethod.Post, + logger: requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: requestContext.UserCancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCsrRequest failed.", + ex, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + var csrRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!CsrRequestResponse.ValidateCsrRequestResponse(csrRequestResponse)) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + return csrRequestResponse; + } + protected override ManagedIdentityRequest CreateRequest(string resource) { + var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); + var csrRequest = CsrRequest.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + + var queryParams = $"cid={csrMetadata.Cuid}"; + if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) + { + queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + } + queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; + + var csrRequestResponse = ExecuteCsrRequestAsync(_requestContext, queryParams, csrRequest.Pem); + throw new NotImplementedException(); } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index e1aea27aa4..4bf53c7a42 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -130,5 +130,7 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } } + + // TODO: Create CSR generation unit tests } } From 6bc21644d3fca39d43d4e04b4640161b98c4c3aa Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 16:50:16 -0400 Subject: [PATCH 02/41] Implemented CSR generator --- .../ManagedIdentity/CsrRequest.cs | 402 +++++++++++++++++- .../ManagedIdentityTests/ImdsV2Tests.cs | 21 +- 2 files changed, 420 insertions(+), 3 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index eda86d9325..aa692a5b0f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. using System; +using System.Security.Cryptography; +using System.Text; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -34,8 +36,404 @@ public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cui if (string.IsNullOrWhiteSpace(cuid.Vmssid)) throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); - // TODO: Implement the actual CSR generation logic. - return new CsrRequest("pem"); + string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); + return new CsrRequest(pemCsr); } + + /// + /// Generates a PKCS#10 Certificate Signing Request in PEM format. + /// + private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidInfo cuid) + { + // Generate RSA key pair (2048-bit) + RSA rsa = CreateRsaKeyPair(); + + try + { + // Build the CSR components + byte[] certificationRequestInfo = BuildCertificationRequestInfo(clientId, tenantId, cuid, rsa); + byte[] signatureAlgorithm = BuildSignatureAlgorithmIdentifier(); + byte[] signature = SignCertificationRequestInfo(certificationRequestInfo, rsa); + + // Combine into final CSR structure + byte[] csrBytes = BuildFinalCsr(certificationRequestInfo, signatureAlgorithm, signature); + + // Convert to PEM format + return ConvertToPem(csrBytes); + } + finally + { + rsa?.Dispose(); + } + } + + /// + /// Creates a 2048-bit RSA key pair compatible with all target frameworks. + /// + private static RSA CreateRsaKeyPair() + { +#if NET462 || NET472 + var rsa = new RSACryptoServiceProvider(2048); + return rsa; +#else + var rsa = RSA.Create(); + rsa.KeySize = 2048; + return rsa; +#endif + } + + /// + /// Builds the CertificationRequestInfo structure containing subject, public key, and attributes. + /// + private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) + { + var components = new System.Collections.Generic.List(); + + // Version (INTEGER 0) + components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); + + // Subject: CN=, DC= + components.Add(BuildSubjectName(clientId, tenantId)); + + // SubjectPublicKeyInfo + components.Add(BuildSubjectPublicKeyInfo(rsa)); + + // Attributes (including CUID) + components.Add(BuildAttributes(cuid)); + + return EncodeAsn1Sequence(components.ToArray()); + } + + /// + /// Builds the X.500 Distinguished Name for the subject. + /// + private static byte[] BuildSubjectName(string clientId, string tenantId) + { + var rdnSequence = new System.Collections.Generic.List(); + + // CN= + byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID + byte[] cnValue = EncodeAsn1Utf8String(clientId); + byte[] cnAttributeValue = EncodeAsn1Sequence(new[] { cnOid, cnValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { cnAttributeValue })); + + // DC= + byte[] dcOid = EncodeAsn1ObjectIdentifier(new int[] { 0, 9, 2342, 19200300, 100, 1, 25 }); // domainComponent OID + byte[] dcValue = EncodeAsn1Utf8String(tenantId); + byte[] dcAttributeValue = EncodeAsn1Sequence(new[] { dcOid, dcValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { dcAttributeValue })); + + return EncodeAsn1Sequence(rdnSequence.ToArray()); + } + + /// + /// Builds the SubjectPublicKeyInfo structure containing the RSA public key. + /// + private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) + { + RSAParameters rsaParams = rsa.ExportParameters(false); + + // RSA Public Key structure + byte[] modulus = EncodeAsn1Integer(rsaParams.Modulus); + byte[] exponent = EncodeAsn1Integer(rsaParams.Exponent); + byte[] rsaPublicKey = EncodeAsn1Sequence(new[] { modulus, exponent }); + + // Algorithm identifier for RSA encryption + byte[] rsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 1 }); // RSA encryption OID + byte[] nullParam = EncodeAsn1Null(); + byte[] algorithmIdentifier = EncodeAsn1Sequence(new[] { rsaOid, nullParam }); + + // SubjectPublicKeyInfo + byte[] publicKeyBitString = EncodeAsn1BitString(rsaPublicKey); + return EncodeAsn1Sequence(new[] { algorithmIdentifier, publicKeyBitString }); + } + + /// + /// Builds the attributes section including the CUID extension. + /// + private static byte[] BuildAttributes(CuidInfo cuid) + { + var attributes = new System.Collections.Generic.List(); + + // CUID attribute (OID 1.2.840.113549.1.9.7) + byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); + string cuidValue = $"{cuid.Vmid}:{cuid.Vmssid}"; + byte[] cuidData = EncodeAsn1PrintableString(cuidValue); + byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); + byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); + attributes.Add(cuidAttribute); + + return EncodeAsn1ContextSpecific(0, EncodeAsn1SequenceRaw(attributes.ToArray())); + } + + /// + /// Builds the signature algorithm identifier for SHA256withRSA. + /// + private static byte[] BuildSignatureAlgorithmIdentifier() + { + byte[] sha256WithRsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 11 }); // SHA256withRSA OID + byte[] nullParam = EncodeAsn1Null(); + return EncodeAsn1Sequence(new[] { sha256WithRsaOid, nullParam }); + } + + /// + /// Signs the CertificationRequestInfo with SHA256withRSA. + /// + private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) + { +#if NET462 || NET472 + using (var sha256 = SHA256.Create()) + { + byte[] hash = sha256.ComputeHash(certificationRequestInfo); + var formatter = new RSAPKCS1SignatureFormatter(rsa); + formatter.SetHashAlgorithm("SHA256"); + return formatter.CreateSignature(hash); + } +#else + return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); +#endif + } + + /// + /// Combines all components into the final CSR structure. + /// + private static byte[] BuildFinalCsr(byte[] certificationRequestInfo, byte[] signatureAlgorithm, byte[] signature) + { + byte[] signatureBitString = EncodeAsn1BitString(signature); + return EncodeAsn1Sequence(new[] { certificationRequestInfo, signatureAlgorithm, signatureBitString }); + } + + /// + /// Converts DER-encoded bytes to PEM format. + /// + private static string ConvertToPem(byte[] derBytes) + { + string base64 = Convert.ToBase64String(derBytes); + var sb = new StringBuilder(); + sb.AppendLine("-----BEGIN CERTIFICATE REQUEST-----"); + + // Split into 64-character lines + for (int i = 0; i < base64.Length; i += 64) + { + int length = Math.Min(64, base64.Length - i); + sb.AppendLine(base64.Substring(i, length)); + } + + sb.AppendLine("-----END CERTIFICATE REQUEST-----"); + return sb.ToString(); + } + + #region ASN.1 Encoding Helpers + + /// + /// Encodes an ASN.1 SEQUENCE. + /// + private static byte[] EncodeAsn1Sequence(byte[][] components) + { + return EncodeAsn1Tag(0x30, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 SEQUENCE without the outer tag (for raw concatenation). + /// + private static byte[] EncodeAsn1SequenceRaw(byte[][] components) + { + return ConcatenateByteArrays(components); + } + + /// + /// Encodes an ASN.1 SET. + /// + private static byte[] EncodeAsn1Set(byte[][] components) + { + return EncodeAsn1Tag(0x31, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 INTEGER. + /// + private static byte[] EncodeAsn1Integer(byte[] value) + { + // Ensure positive integer (prepend 0x00 if high bit is set) + if (value != null && value.Length > 0 && (value[0] & 0x80) != 0) + { + byte[] paddedValue = new byte[value.Length + 1]; + paddedValue[0] = 0x00; + Array.Copy(value, 0, paddedValue, 1, value.Length); + value = paddedValue; + } + return EncodeAsn1Tag(0x02, value ?? new byte[0]); + } + + /// + /// Encodes an ASN.1 INTEGER from an integer value. + /// + private static byte[] EncodeAsn1Integer(int value) + { + if (value == 0) + return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); + + var bytes = new System.Collections.Generic.List(); + int temp = value; + while (temp > 0) + { + bytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + return EncodeAsn1Integer(bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 BIT STRING. + /// + private static byte[] EncodeAsn1BitString(byte[] value) + { + byte[] bitStringValue = new byte[value.Length + 1]; + bitStringValue[0] = 0x00; // No unused bits + Array.Copy(value, 0, bitStringValue, 1, value.Length); + return EncodeAsn1Tag(0x03, bitStringValue); + } + + /// + /// Encodes an ASN.1 UTF8String. + /// + private static byte[] EncodeAsn1Utf8String(string value) + { + byte[] utf8Bytes = Encoding.UTF8.GetBytes(value); + return EncodeAsn1Tag(0x0C, utf8Bytes); + } + + /// + /// Encodes an ASN.1 PrintableString. + /// + private static byte[] EncodeAsn1PrintableString(string value) + { + byte[] asciiBytes = Encoding.ASCII.GetBytes(value); + return EncodeAsn1Tag(0x13, asciiBytes); + } + + /// + /// Encodes an ASN.1 NULL. + /// + private static byte[] EncodeAsn1Null() + { + return new byte[] { 0x05, 0x00 }; + } + + /// + /// Encodes an ASN.1 OBJECT IDENTIFIER. + /// + private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) + { + if (oid == null || oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var bytes = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + bytes.AddRange(EncodeOidComponent(oid[i])); + } + + return EncodeAsn1Tag(0x06, bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 context-specific tag. + /// + private static byte[] EncodeAsn1ContextSpecific(int tagNumber, byte[] content) + { + byte tag = (byte)(0xA0 | tagNumber); // Context-specific, constructed + return EncodeAsn1Tag(tag, content); + } + + /// + /// Encodes an ASN.1 tag with length and content. + /// + private static byte[] EncodeAsn1Tag(byte tag, byte[] content) + { + byte[] lengthBytes = EncodeAsn1Length(content.Length); + byte[] result = new byte[1 + lengthBytes.Length + content.Length]; + result[0] = tag; + Array.Copy(lengthBytes, 0, result, 1, lengthBytes.Length); + Array.Copy(content, 0, result, 1 + lengthBytes.Length, content.Length); + return result; + } + + /// + /// Encodes ASN.1 length field. + /// + private static byte[] EncodeAsn1Length(int length) + { + if (length < 0x80) + { + return new byte[] { (byte)length }; + } + + var lengthBytes = new System.Collections.Generic.List(); + int temp = length; + while (temp > 0) + { + lengthBytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + byte[] result = new byte[lengthBytes.Count + 1]; + result[0] = (byte)(0x80 | lengthBytes.Count); + lengthBytes.CopyTo(result, 1); + return result; + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + + /// + /// Concatenates multiple byte arrays. + /// + private static byte[] ConcatenateByteArrays(byte[][] arrays) + { + int totalLength = 0; + foreach (byte[] array in arrays) + { + totalLength += array.Length; + } + + byte[] result = new byte[totalLength]; + int offset = 0; + foreach (byte[] array in arrays) + { + Array.Copy(array, 0, result, offset, array.Length); + offset += array.Length; + } + + return result; + } + + #endregion } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 4bf53c7a42..830082c140 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client; @@ -131,6 +132,24 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs } } - // TODO: Create CSR generation unit tests + [TestMethod] + public void TestCsrGeneration() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Generate CSR + var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + + // Output the generated CSR for analysis + System.Console.WriteLine("Generated CSR:"); + System.Console.WriteLine(csrRequest.Pem); + } } } From 762ccdfbcce6ec145375354b9d588742b3b3ba2a Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 17:00:50 -0400 Subject: [PATCH 03/41] first pass at improved unit tests --- .../ManagedIdentityTests/ImdsV2Tests.cs | 486 ++++++++++++++++++ 1 file changed, 486 insertions(+) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 830082c140..f5c8c2ed0a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -150,6 +150,492 @@ public void TestCsrGeneration() // Output the generated CSR for analysis System.Console.WriteLine("Generated CSR:"); System.Console.WriteLine(csrRequest.Pem); + + // Validate the CSR contents + ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + } + + [TestMethod] + public void TestCsrGeneration_InvalidClientId() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Test with null client ID + Assert.ThrowsException(() => + CsrRequest.Generate(null, tenantId, cuid)); + + // Test with empty client ID + Assert.ThrowsException(() => + CsrRequest.Generate("", tenantId, cuid)); + + // Test with whitespace client ID + Assert.ThrowsException(() => + CsrRequest.Generate(" ", tenantId, cuid)); + } + + [TestMethod] + public void TestCsrGeneration_InvalidTenantId() + { + var cuid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "test-vmss-id-67890" + }; + + string clientId = "12345678-1234-1234-1234-123456789012"; + + // Test with null tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, null, cuid)); + + // Test with empty tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, "", cuid)); + + // Test with whitespace tenant ID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, " ", cuid)); + } + + [TestMethod] + public void TestCsrGeneration_InvalidCuid() + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + // Test with null CUID + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, null)); + + // Test with null VMID + var cuidWithNullVmid = new CuidInfo + { + Vmid = null, + Vmssid = "test-vmss-id-67890" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithNullVmid)); + + // Test with empty VMID + var cuidWithEmptyVmid = new CuidInfo + { + Vmid = "", + Vmssid = "test-vmss-id-67890" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmid)); + + // Test with null VMSSID + var cuidWithNullVmssid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = null + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithNullVmssid)); + + // Test with empty VMSSID + var cuidWithEmptyVmssid = new CuidInfo + { + Vmid = "test-vm-id-12345", + Vmssid = "" + }; + Assert.ThrowsException(() => + CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmssid)); + } + + [TestMethod] + public void TestCsrGeneration_MalformedPem() + { + // Test parsing malformed PEM with invalid Base64 characters + string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; + + Assert.ThrowsException(() => + ParseCsrFromPem(malformedPem)); + + // Test with wrong headers + string wrongHeaders = "-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE-----"; + + Assert.ThrowsException(() => + ParseCsrFromPem(wrongHeaders)); + } + + #region CSR Validation Helper Methods + + /// + /// Validates the content of a CSR PEM string against expected values. + /// + private void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + // Parse the CSR from PEM format + var csrData = ParseCsrFromPem(pemCsr); + + // Parse the PKCS#10 structure + var csrInfo = ParsePkcs10Structure(csrData); + + // Validate subject name + ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); + + // Validate public key + ValidatePublicKey(csrInfo.PublicKey); + + // Validate CUID attribute + ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); + + // Validate signature algorithm + ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); + } + + /// + /// Parses a PEM-formatted CSR and returns the DER bytes. + /// + private byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + + /// + /// Represents parsed PKCS#10 CSR information. + /// + private class CsrInfo + { + public byte[] Subject { get; set; } + public byte[] PublicKey { get; set; } + public byte[] Attributes { get; set; } + public byte[] SignatureAlgorithm { get; set; } + } + + /// + /// Parses the PKCS#10 ASN.1 structure and extracts key components. + /// + private CsrInfo ParsePkcs10Structure(byte[] derBytes) + { + int offset = 0; + + // Parse outer SEQUENCE (CertificationRequest) + var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); + + // Reset offset to parse the CertificationRequestInfo within the outer sequence + int infoOffset = 0; + var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); + + // Parse version (should be 0) + int versionOffset = 0; + var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); + if (version.Length != 1 || version[0] != 0x00) + throw new ArgumentException("Invalid CSR version"); + + // Parse subject + var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse SubjectPublicKeyInfo + var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse attributes (context-specific [0]) + var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); + + return new CsrInfo + { + Subject = subject, + PublicKey = publicKey, + Attributes = attributes, + SignatureAlgorithm = new byte[0] // Simplified for this test + }; + } + + /// + /// Parses an ASN.1 tag and returns its content. + /// + private byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data"); + + // Check tag (if expectedTag is -1, accept any tag) + if (expectedTag != 255 && data[offset] != expectedTag) + throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); + + offset++; + + // Parse length + int length = ParseAsn1Length(data, ref offset); + + // Extract content + if (offset + length > data.Length) + throw new ArgumentException("Invalid ASN.1 length"); + + byte[] content = new byte[length]; + Array.Copy(data, offset, content, 0, length); + offset += length; + + return content; + } + + /// + /// Parses ASN.1 length encoding. + /// + private int ParseAsn1Length(byte[] data, ref int offset) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data in length"); + + byte firstByte = data[offset++]; + + // Short form + if ((firstByte & 0x80) == 0) + return firstByte; + + // Long form + int lengthBytes = firstByte & 0x7F; + if (lengthBytes == 0) + throw new ArgumentException("Indefinite length not supported"); + + if (offset + lengthBytes > data.Length) + throw new ArgumentException("Invalid length encoding"); + + int length = 0; + for (int i = 0; i < lengthBytes; i++) + { + length = (length << 8) | data[offset++]; + } + + return length; + } + + /// + /// Validates the subject name contains the expected client ID and tenant ID. + /// + private void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) + { + // Subject is already a SEQUENCE of RDNs + int offset = 0; + bool foundClientId = false; + bool foundTenantId = false; + + // Parse each RDN (Relative Distinguished Name) directly from subjectBytes + while (offset < subjectBytes.Length) + { + var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET + + int rdnOffset = 0; + var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE + + // Parse OID and value + int attrOffset = 0; + var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID + var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type + + string stringValue = System.Text.Encoding.UTF8.GetString(value); + + // Check for CN (commonName) OID: 2.5.4.3 + if (IsOid(oid, new int[] { 2, 5, 4, 3 })) + { + Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); + foundClientId = true; + } + // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 + else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) + { + Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); + foundTenantId = true; + } + } + + Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); + Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); + } + + /// + /// Validates the public key is a valid RSA key. + /// + private void ValidatePublicKey(byte[] publicKeyBytes) + { + // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content + int offset = 0; + + // Parse algorithm identifier + var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); + + // Parse public key bit string + var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); + + // Validate algorithm is RSA (1.2.840.113549.1.1.1) + int algOffset = 0; + var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); + Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), + "Public key algorithm is not RSA"); + + // Skip the unused bits byte in bit string + if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) + throw new ArgumentException("Invalid public key bit string"); + + // Parse RSA public key (skip unused bits byte) + byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; + Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); + + int rsaOffset = 0; + var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); + + rsaOffset = 0; + var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + + // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) + Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, + $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + } + + /// + /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// + private void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) + { + // Attributes is a SET of attributes + // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) + + int offset = 0; + bool foundCuid = false; + + // Parse each attribute in the SET + while (offset < attributesBytes.Length) + { + var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); + + int attrOffset = 0; + var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); + var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values + + // Check for challengePassword OID: 1.2.840.113549.1.9.7 + if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + { + // Parse the value from the SET (should be one value) + int valueOffset = 0; + var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type + + string cuidValue = System.Text.Encoding.ASCII.GetString(value); + string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + foundCuid = true; + break; + } + } + + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + + /// + /// Validates the signature algorithm is SHA256withRSA. + /// + private void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) + { + // For this test, we'll just verify that signature algorithm exists + // Full validation would require parsing the outer CSR structure + // which is more complex for this unit test scenario + Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); } + + /// + /// Checks if the given OID bytes match the expected OID components. + /// + private bool IsOid(byte[] oidBytes, int[] expectedOid) + { + if (expectedOid.Length < 2) + return false; + + var expectedBytes = EncodeOid(expectedOid); + + if (oidBytes.Length != expectedBytes.Length) + return false; + + for (int i = 0; i < oidBytes.Length; i++) + { + if (oidBytes[i] != expectedBytes[i]) + return false; + } + + return true; + } + + /// + /// Encodes an OID from integer components to bytes (simplified version). + /// + private byte[] EncodeOid(int[] oid) + { + if (oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var result = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + result.AddRange(EncodeOidComponent(oid[i])); + } + + return result.ToArray(); + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + + #endregion } } From 4ea6c09af1ef8dc267cb6547522f175a838775d5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 6 Aug 2025 18:06:47 -0400 Subject: [PATCH 04/41] Finished improving unit tests --- .../ImdsV2ManagedIdentitySource.cs | 2 +- .../ManagedIdentityTests/CsrValidator.cs | 419 +++++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 485 ++---------------- 3 files changed, 455 insertions(+), 451 deletions(-) create mode 100644 tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 6fef52849b..08bbc9caf4 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -91,7 +91,7 @@ public static async Task GetCsrMetadataAsync( } } - if (!probeMode && !ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) + if (!ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) { return null; } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs new file mode 100644 index 0000000000..b4a4e99dd5 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -0,0 +1,419 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + /// + /// Test helper to expose CsrValidator methods for testing malformed PEM. + /// + internal static class TestCsrValidator + { + public static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + } + + /// + /// Helper class for validating Certificate Signing Request (CSR) content and structure. + /// + internal static class CsrValidator + { + /// + /// Validates the content of a CSR PEM string against expected values. + /// + public static void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + // Parse the CSR from PEM format + var csrData = ParseCsrFromPem(pemCsr); + + // Parse the PKCS#10 structure + var csrInfo = ParsePkcs10Structure(csrData); + + // Validate subject name + ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); + + // Validate public key + ValidatePublicKey(csrInfo.PublicKey); + + // Validate CUID attribute + ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); + + // Validate signature algorithm + ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); + } + + /// + /// Parses a PEM-formatted CSR and returns the DER bytes. + /// + private static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + + /// + /// Represents parsed PKCS#10 CSR information. + /// + private class CsrInfo + { + public byte[] Subject { get; set; } + public byte[] PublicKey { get; set; } + public byte[] Attributes { get; set; } + public byte[] SignatureAlgorithm { get; set; } + } + + /// + /// Parses the PKCS#10 ASN.1 structure and extracts key components. + /// + private static CsrInfo ParsePkcs10Structure(byte[] derBytes) + { + int offset = 0; + + // Parse outer SEQUENCE (CertificationRequest) + var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); + + // Reset offset to parse the CertificationRequestInfo within the outer sequence + int infoOffset = 0; + var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); + + // Parse version (should be 0) + int versionOffset = 0; + var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); + if (version.Length != 1 || version[0] != 0x00) + throw new ArgumentException("Invalid CSR version"); + + // Parse subject + var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse SubjectPublicKeyInfo + var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse attributes (context-specific [0]) + var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); + + return new CsrInfo + { + Subject = subject, + PublicKey = publicKey, + Attributes = attributes, + SignatureAlgorithm = new byte[0] // Simplified for this test + }; + } + + /// + /// Parses an ASN.1 tag and returns its content. + /// + private static byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data"); + + // Check tag (if expectedTag is -1, accept any tag) + if (expectedTag != 255 && data[offset] != expectedTag) + throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); + + offset++; + + // Parse length + int length = ParseAsn1Length(data, ref offset); + + // Extract content + if (offset + length > data.Length) + throw new ArgumentException("Invalid ASN.1 length"); + + byte[] content = new byte[length]; + Array.Copy(data, offset, content, 0, length); + offset += length; + + return content; + } + + /// + /// Parses ASN.1 length encoding. + /// + private static int ParseAsn1Length(byte[] data, ref int offset) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data in length"); + + byte firstByte = data[offset++]; + + // Short form + if ((firstByte & 0x80) == 0) + return firstByte; + + // Long form + int lengthBytes = firstByte & 0x7F; + if (lengthBytes == 0) + throw new ArgumentException("Indefinite length not supported"); + + if (offset + lengthBytes > data.Length) + throw new ArgumentException("Invalid length encoding"); + + int length = 0; + for (int i = 0; i < lengthBytes; i++) + { + length = (length << 8) | data[offset++]; + } + + return length; + } + + /// + /// Validates the subject name contains the expected client ID and tenant ID. + /// + private static void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) + { + // Subject is already a SEQUENCE of RDNs + int offset = 0; + bool foundClientId = false; + bool foundTenantId = false; + + // Parse each RDN (Relative Distinguished Name) directly from subjectBytes + while (offset < subjectBytes.Length) + { + var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET + + int rdnOffset = 0; + var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE + + // Parse OID and value + int attrOffset = 0; + var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID + var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type + + string stringValue = System.Text.Encoding.UTF8.GetString(value); + + // Check for CN (commonName) OID: 2.5.4.3 + if (IsOid(oid, new int[] { 2, 5, 4, 3 })) + { + Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); + foundClientId = true; + } + // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 + else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) + { + Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); + foundTenantId = true; + } + } + + Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); + Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); + } + + /// + /// Validates the public key is a valid RSA key. + /// + private static void ValidatePublicKey(byte[] publicKeyBytes) + { + // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content + int offset = 0; + + // Parse algorithm identifier + var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); + + // Parse public key bit string + var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); + + // Validate algorithm is RSA (1.2.840.113549.1.1.1) + int algOffset = 0; + var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); + Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), + "Public key algorithm is not RSA"); + + // Skip the unused bits byte in bit string + if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) + throw new ArgumentException("Invalid public key bit string"); + + // Parse RSA public key (skip unused bits byte) + byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; + Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); + + int rsaOffset = 0; + var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); + + rsaOffset = 0; + var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + + // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) + Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, + $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + } + + /// + /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// + private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) + { + // Attributes is a SET of attributes + // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) + + int offset = 0; + bool foundCuid = false; + + // Parse each attribute in the SET + while (offset < attributesBytes.Length) + { + var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); + + int attrOffset = 0; + var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); + var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values + + // Check for challengePassword OID: 1.2.840.113549.1.9.7 + if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + { + // Parse the value from the SET (should be one value) + int valueOffset = 0; + var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type + + string cuidValue = System.Text.Encoding.ASCII.GetString(value); + string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + foundCuid = true; + break; + } + } + + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + + /// + /// Validates the signature algorithm is SHA256withRSA. + /// + private static void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) + { + // For this test, we'll just verify that signature algorithm exists + // Full validation would require parsing the outer CSR structure + // which is more complex for this unit test scenario + Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); + } + + /// + /// Checks if the given OID bytes match the expected OID components. + /// + private static bool IsOid(byte[] oidBytes, int[] expectedOid) + { + if (expectedOid.Length < 2) + return false; + + var expectedBytes = EncodeOid(expectedOid); + + if (oidBytes.Length != expectedBytes.Length) + return false; + + for (int i = 0; i < oidBytes.Length; i++) + { + if (oidBytes[i] != expectedBytes[i]) + return false; + } + + return true; + } + + /// + /// Encodes an OID from integer components to bytes (simplified version). + /// + private static byte[] EncodeOid(int[] oid) + { + if (oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var result = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + result.AddRange(EncodeOidComponent(oid[i])); + } + + return result.ToArray(); + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index f5c8c2ed0a..b9ee9b2268 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -151,12 +151,18 @@ public void TestCsrGeneration() System.Console.WriteLine("Generated CSR:"); System.Console.WriteLine(csrRequest.Pem); - // Validate the CSR contents - ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + // Validate the CSR contents using the helper + CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); } - [TestMethod] - public void TestCsrGeneration_InvalidClientId() + [DataTestMethod] + [DataRow(null, "87654321-4321-4321-4321-210987654321", DisplayName = "Null ClientId")] + [DataRow("", "87654321-4321-4321-4321-210987654321", DisplayName = "Empty ClientId")] + [DataRow(" ", "87654321-4321-4321-4321-210987654321", DisplayName = "Whitespace ClientId")] + [DataRow("12345678-1234-1234-1234-123456789012", null, DisplayName = "Null TenantId")] + [DataRow("12345678-1234-1234-1234-123456789012", "", DisplayName = "Empty TenantId")] + [DataRow("12345678-1234-1234-1234-123456789012", " ", DisplayName = "Whitespace TenantId")] + public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) { var cuid = new CuidInfo { @@ -164,47 +170,12 @@ public void TestCsrGeneration_InvalidClientId() Vmssid = "test-vmss-id-67890" }; - string tenantId = "87654321-4321-4321-4321-210987654321"; - - // Test with null client ID - Assert.ThrowsException(() => - CsrRequest.Generate(null, tenantId, cuid)); - - // Test with empty client ID - Assert.ThrowsException(() => - CsrRequest.Generate("", tenantId, cuid)); - - // Test with whitespace client ID Assert.ThrowsException(() => - CsrRequest.Generate(" ", tenantId, cuid)); + CsrRequest.Generate(clientId, tenantId, cuid)); } [TestMethod] - public void TestCsrGeneration_InvalidTenantId() - { - var cuid = new CuidInfo - { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" - }; - - string clientId = "12345678-1234-1234-1234-123456789012"; - - // Test with null tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, null, cuid)); - - // Test with empty tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, "", cuid)); - - // Test with whitespace tenant ID - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, " ", cuid)); - } - - [TestMethod] - public void TestCsrGeneration_InvalidCuid() + public void TestCsrGeneration_NullCuid() { string clientId = "12345678-1234-1234-1234-123456789012"; string tenantId = "87654321-4321-4321-4321-210987654321"; @@ -212,430 +183,44 @@ public void TestCsrGeneration_InvalidCuid() // Test with null CUID Assert.ThrowsException(() => CsrRequest.Generate(clientId, tenantId, null)); + } - // Test with null VMID - var cuidWithNullVmid = new CuidInfo - { - Vmid = null, - Vmssid = "test-vmss-id-67890" - }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithNullVmid)); - - // Test with empty VMID - var cuidWithEmptyVmid = new CuidInfo - { - Vmid = "", - Vmssid = "test-vmss-id-67890" - }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmid)); + [DataTestMethod] + [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] + [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] + [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] + [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; - // Test with null VMSSID - var cuidWithNullVmssid = new CuidInfo + var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = null + Vmid = vmid, + Vmssid = vmssid }; - Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithNullVmssid)); - // Test with empty VMSSID - var cuidWithEmptyVmssid = new CuidInfo - { - Vmid = "test-vm-id-12345", - Vmssid = "" - }; Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuidWithEmptyVmssid)); + CsrRequest.Generate(clientId, tenantId, cuid)); } [TestMethod] - public void TestCsrGeneration_MalformedPem() + public void TestCsrGeneration_MalformedPem_FormatException() { - // Test parsing malformed PEM with invalid Base64 characters string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; - Assert.ThrowsException(() => - ParseCsrFromPem(malformedPem)); - - // Test with wrong headers - string wrongHeaders = "-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE-----"; - - Assert.ThrowsException(() => - ParseCsrFromPem(wrongHeaders)); - } - - #region CSR Validation Helper Methods - - /// - /// Validates the content of a CSR PEM string against expected values. - /// - private void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) - { - // Parse the CSR from PEM format - var csrData = ParseCsrFromPem(pemCsr); - - // Parse the PKCS#10 structure - var csrInfo = ParsePkcs10Structure(csrData); - - // Validate subject name - ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); - - // Validate public key - ValidatePublicKey(csrInfo.PublicKey); - - // Validate CUID attribute - ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); - - // Validate signature algorithm - ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); - } - - /// - /// Parses a PEM-formatted CSR and returns the DER bytes. - /// - private byte[] ParseCsrFromPem(string pemCsr) - { - if (string.IsNullOrWhiteSpace(pemCsr)) - throw new ArgumentException("PEM CSR cannot be null or empty"); - - const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; - const string endMarker = "-----END CERTIFICATE REQUEST-----"; - - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); - - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); - - string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) - .Replace("\r", "").Replace("\n", "").Replace(" ", ""); - - try - { - return Convert.FromBase64String(base64Content); - } - catch (FormatException) - { - throw new FormatException("Invalid Base64 content in PEM CSR"); - } - } - - /// - /// Represents parsed PKCS#10 CSR information. - /// - private class CsrInfo - { - public byte[] Subject { get; set; } - public byte[] PublicKey { get; set; } - public byte[] Attributes { get; set; } - public byte[] SignatureAlgorithm { get; set; } + TestCsrValidator.ParseCsrFromPem(malformedPem)); } - /// - /// Parses the PKCS#10 ASN.1 structure and extracts key components. - /// - private CsrInfo ParsePkcs10Structure(byte[] derBytes) + [DataTestMethod] + [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----", DisplayName = "Wrong Headers")] + [DataRow("", DisplayName = "Empty PEM")] + [DataRow(null, DisplayName = "Null PEM")] + public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { - int offset = 0; - - // Parse outer SEQUENCE (CertificationRequest) - var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); - - // Reset offset to parse the CertificationRequestInfo within the outer sequence - int infoOffset = 0; - var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); - - // Parse version (should be 0) - int versionOffset = 0; - var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); - if (version.Length != 1 || version[0] != 0x00) - throw new ArgumentException("Invalid CSR version"); - - // Parse subject - var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse SubjectPublicKeyInfo - var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse attributes (context-specific [0]) - var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); - - return new CsrInfo - { - Subject = subject, - PublicKey = publicKey, - Attributes = attributes, - SignatureAlgorithm = new byte[0] // Simplified for this test - }; - } - - /// - /// Parses an ASN.1 tag and returns its content. - /// - private byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data"); - - // Check tag (if expectedTag is -1, accept any tag) - if (expectedTag != 255 && data[offset] != expectedTag) - throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); - - offset++; - - // Parse length - int length = ParseAsn1Length(data, ref offset); - - // Extract content - if (offset + length > data.Length) - throw new ArgumentException("Invalid ASN.1 length"); - - byte[] content = new byte[length]; - Array.Copy(data, offset, content, 0, length); - offset += length; - - return content; - } - - /// - /// Parses ASN.1 length encoding. - /// - private int ParseAsn1Length(byte[] data, ref int offset) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data in length"); - - byte firstByte = data[offset++]; - - // Short form - if ((firstByte & 0x80) == 0) - return firstByte; - - // Long form - int lengthBytes = firstByte & 0x7F; - if (lengthBytes == 0) - throw new ArgumentException("Indefinite length not supported"); - - if (offset + lengthBytes > data.Length) - throw new ArgumentException("Invalid length encoding"); - - int length = 0; - for (int i = 0; i < lengthBytes; i++) - { - length = (length << 8) | data[offset++]; - } - - return length; - } - - /// - /// Validates the subject name contains the expected client ID and tenant ID. - /// - private void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) - { - // Subject is already a SEQUENCE of RDNs - int offset = 0; - bool foundClientId = false; - bool foundTenantId = false; - - // Parse each RDN (Relative Distinguished Name) directly from subjectBytes - while (offset < subjectBytes.Length) - { - var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET - - int rdnOffset = 0; - var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE - - // Parse OID and value - int attrOffset = 0; - var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID - var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type - - string stringValue = System.Text.Encoding.UTF8.GetString(value); - - // Check for CN (commonName) OID: 2.5.4.3 - if (IsOid(oid, new int[] { 2, 5, 4, 3 })) - { - Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); - foundClientId = true; - } - // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 - else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) - { - Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); - foundTenantId = true; - } - } - - Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); - Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); - } - - /// - /// Validates the public key is a valid RSA key. - /// - private void ValidatePublicKey(byte[] publicKeyBytes) - { - // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content - int offset = 0; - - // Parse algorithm identifier - var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); - - // Parse public key bit string - var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); - - // Validate algorithm is RSA (1.2.840.113549.1.1.1) - int algOffset = 0; - var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); - Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), - "Public key algorithm is not RSA"); - - // Skip the unused bits byte in bit string - if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) - throw new ArgumentException("Invalid public key bit string"); - - // Parse RSA public key (skip unused bits byte) - byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; - Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); - - int rsaOffset = 0; - var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); - - rsaOffset = 0; - var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - - // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) - Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, - $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - - // Validate exponent (commonly 65537 = 0x010001) - Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); - } - - /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs. - /// - private void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) - { - // Attributes is a SET of attributes - // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) - - int offset = 0; - bool foundCuid = false; - - // Parse each attribute in the SET - while (offset < attributesBytes.Length) - { - var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); - - int attrOffset = 0; - var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); - var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values - - // Check for challengePassword OID: 1.2.840.113549.1.9.7 - if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) - { - // Parse the value from the SET (should be one value) - int valueOffset = 0; - var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type - - string cuidValue = System.Text.Encoding.ASCII.GetString(value); - string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; - - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); - foundCuid = true; - break; - } - } - - Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); - } - - /// - /// Validates the signature algorithm is SHA256withRSA. - /// - private void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) - { - // For this test, we'll just verify that signature algorithm exists - // Full validation would require parsing the outer CSR structure - // which is more complex for this unit test scenario - Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); - } - - /// - /// Checks if the given OID bytes match the expected OID components. - /// - private bool IsOid(byte[] oidBytes, int[] expectedOid) - { - if (expectedOid.Length < 2) - return false; - - var expectedBytes = EncodeOid(expectedOid); - - if (oidBytes.Length != expectedBytes.Length) - return false; - - for (int i = 0; i < oidBytes.Length; i++) - { - if (oidBytes[i] != expectedBytes[i]) - return false; - } - - return true; - } - - /// - /// Encodes an OID from integer components to bytes (simplified version). - /// - private byte[] EncodeOid(int[] oid) - { - if (oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var result = new System.Collections.Generic.List(); - - // First two components are encoded as (first * 40 + second) - result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - result.AddRange(EncodeOidComponent(oid[i])); - } - - return result.ToArray(); - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new System.Collections.Generic.List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); + Assert.ThrowsException(() => + TestCsrValidator.ParseCsrFromPem(malformedPem)); } - - #endregion } } From 009f948a9b7267d00a0183439f0912cd24f338ad Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 12:02:32 -0400 Subject: [PATCH 05/41] Updates to CUID --- .../ManagedIdentity/CsrMetadata.cs | 3 +- .../ManagedIdentity/CsrRequest.cs | 8 ++--- .../ManagedIdentityTests/CsrValidator.cs | 18 +++++++++-- .../ManagedIdentityTests/ImdsV2Tests.cs | 32 +++++++++++++++---- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs index 04a9e06baf..a831d02c7a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs @@ -57,13 +57,12 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any field is null. + /// false if any required field is null. Note: Vmid is required, Vmssid is optional. public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || csrMetadata.Cuid == null || string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) || - string.IsNullOrEmpty(csrMetadata.Cuid.Vmssid) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index aa692a5b0f..2604a8c31e 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -4,6 +4,7 @@ using System; using System.Security.Cryptography; using System.Text; +using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -21,7 +22,7 @@ public CsrRequest(string pem) /// /// Managed Identity client_id. /// AAD tenant_id. - /// CuidInfo object containing VMID and VMSSID. + /// CuidInfo object containing required VMID and optional VMSSID. /// CsrRequest containing the PEM CSR. public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) { @@ -33,8 +34,6 @@ public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cui throw new ArgumentNullException(nameof(cuid)); if (string.IsNullOrWhiteSpace(cuid.Vmid)) throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); - if (string.IsNullOrWhiteSpace(cuid.Vmssid)) - throw new ArgumentException("cuid.Vmssid must not be null or empty.", nameof(cuid.Vmssid)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); return new CsrRequest(pemCsr); @@ -156,8 +155,9 @@ private static byte[] BuildAttributes(CuidInfo cuid) var attributes = new System.Collections.Generic.List(); // CUID attribute (OID 1.2.840.113549.1.9.7) + // Serialize CuidInfo as JSON object string using existing JSON serialization byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); - string cuidValue = $"{cuid.Vmid}:{cuid.Vmssid}"; + string cuidValue = JsonHelper.SerializeToJson(cuid); byte[] cuidData = EncodeAsn1PrintableString(cuidValue); byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index b4a4e99dd5..671700c100 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -3,6 +3,7 @@ using System; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.Utils; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests @@ -300,7 +301,8 @@ private static void ValidatePublicKey(byte[] publicKeyBytes) } /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs. + /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. + /// Note: Vmid is required, Vmssid is optional and will be omitted if null/empty. /// private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) { @@ -327,9 +329,11 @@ private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expec var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type string cuidValue = System.Text.Encoding.ASCII.GetString(value); - string expectedCuidValue = $"{expectedCuid.Vmid}:{expectedCuid.Vmssid}"; - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute value does not match expected"); + // Build expected CUID value as JSON + string expectedCuidValue = BuildExpectedCuidJson(expectedCuid); + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute JSON value does not match expected"); foundCuid = true; break; } @@ -338,6 +342,14 @@ private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expec Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); } + /// + /// Builds the expected CUID JSON string for validation using JsonHelper. + /// + private static string BuildExpectedCuidJson(CuidInfo expectedCuid) + { + return JsonHelper.SerializeToJson(expectedCuid); + } + /// /// Validates the signature algorithm is SHA256withRSA. /// diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index b9ee9b2268..a3b9fd7df8 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -147,10 +147,6 @@ public void TestCsrGeneration() // Generate CSR var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); - // Output the generated CSR for analysis - System.Console.WriteLine("Generated CSR:"); - System.Console.WriteLine(csrRequest.Pem); - // Validate the CSR contents using the helper CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); } @@ -188,9 +184,7 @@ public void TestCsrGeneration_NullCuid() [DataTestMethod] [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] - [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] - [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] - public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) + public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) { string clientId = "12345678-1234-1234-1234-123456789012"; string tenantId = "87654321-4321-4321-4321-210987654321"; @@ -201,10 +195,34 @@ public void TestCsrGeneration_InvalidCuidProperties(string vmid, string vmssid) Vmssid = vmssid }; + // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => CsrRequest.Generate(clientId, tenantId, cuid)); } + [DataTestMethod] + [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] + [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) + { + string clientId = "12345678-1234-1234-1234-123456789012"; + string tenantId = "87654321-4321-4321-4321-210987654321"; + + var cuid = new CuidInfo + { + Vmid = vmid, + Vmssid = vmssid + }; + + // Should succeed since Vmssid is optional (Vmid is provided and valid) + var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + Assert.IsNotNull(csrRequest); + Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); + + // Validate the CSR contents - this should handle null/empty VMSSID gracefully + CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + } + [TestMethod] public void TestCsrGeneration_MalformedPem_FormatException() { From 21d4ef3663cad4c52ec024da257e7ca276309a5f Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 13:57:26 -0400 Subject: [PATCH 06/41] Unit test improvements --- .../TestConstants.cs | 2 + .../ManagedIdentityTests/ImdsV2Tests.cs | 56 +++++++------------ 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 3d89cc1bbe..d4a63354c0 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -154,6 +154,8 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + public const string Vmid = "test-vm-id"; + public const string Vmssid = "test-vmss-id"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index a3b9fd7df8..2e09ace05e 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -137,33 +137,28 @@ public void TestCsrGeneration() { var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid }; - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - // Generate CSR - var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [DataTestMethod] - [DataRow(null, "87654321-4321-4321-4321-210987654321", DisplayName = "Null ClientId")] - [DataRow("", "87654321-4321-4321-4321-210987654321", DisplayName = "Empty ClientId")] - [DataRow(" ", "87654321-4321-4321-4321-210987654321", DisplayName = "Whitespace ClientId")] - [DataRow("12345678-1234-1234-1234-123456789012", null, DisplayName = "Null TenantId")] - [DataRow("12345678-1234-1234-1234-123456789012", "", DisplayName = "Empty TenantId")] - [DataRow("12345678-1234-1234-1234-123456789012", " ", DisplayName = "Whitespace TenantId")] + [DataRow(null, TestConstants.TenantId)] + [DataRow("", TestConstants.TenantId)] + [DataRow(TestConstants.ClientId, null)] + [DataRow(TestConstants.ClientId, "")] public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) { var cuid = new CuidInfo { - Vmid = "test-vm-id-12345", - Vmssid = "test-vmss-id-67890" + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid }; Assert.ThrowsException(() => @@ -173,22 +168,16 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId [TestMethod] public void TestCsrGeneration_NullCuid() { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - // Test with null CUID Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, null)); + CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); } [DataTestMethod] - [DataRow(null, "test-vmss-id-67890", DisplayName = "Null VMID")] - [DataRow("", "test-vmss-id-67890", DisplayName = "Empty VMID")] + [DataRow(null, TestConstants.Vmssid)] + [DataRow("", TestConstants.Vmssid)] public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - var cuid = new CuidInfo { Vmid = vmid, @@ -197,17 +186,14 @@ public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuid)); + CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] - [DataRow("test-vm-id-12345", null, DisplayName = "Null VMSSID")] - [DataRow("test-vm-id-12345", "", DisplayName = "Empty VMSSID")] + [DataRow(TestConstants.Vmid, null)] + [DataRow(TestConstants.Vmid, "")] public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) { - string clientId = "12345678-1234-1234-1234-123456789012"; - string tenantId = "87654321-4321-4321-4321-210987654321"; - var cuid = new CuidInfo { Vmid = vmid, @@ -215,12 +201,12 @@ public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) }; // Should succeed since Vmssid is optional (Vmid is provided and valid) - var csrRequest = CsrRequest.Generate(clientId, tenantId, cuid); + var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); // Validate the CSR contents - this should handle null/empty VMSSID gracefully - CsrValidator.ValidateCsrContent(csrRequest.Pem, clientId, tenantId, cuid); + CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] @@ -232,9 +218,9 @@ public void TestCsrGeneration_MalformedPem_FormatException() } [DataTestMethod] - [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----", DisplayName = "Wrong Headers")] - [DataRow("", DisplayName = "Empty PEM")] - [DataRow(null, DisplayName = "Null PEM")] + [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----")] + [DataRow("")] + [DataRow(null)] public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { Assert.ThrowsException(() => From cd013a33c09d4c81b13206edd625d9941c593855 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 14:15:05 -0400 Subject: [PATCH 07/41] Implemented Feedback --- ...e.cs => ClientCredentialRequestResponse.cs} | 18 +++++++++--------- .../ManagedIdentity/CsrRequest.cs | 14 +++++++------- .../ImdsV2ManagedIdentitySource.cs | 16 ++++++++-------- .../Microsoft.Identity.Client.csproj | 5 +++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 12 ++++++------ 5 files changed, 35 insertions(+), 30 deletions(-) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{CsrRequestResponse.cs => ClientCredentialRequestResponse.cs} (60%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs similarity index 60% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index 10274e48ba..924dd75ad1 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -13,7 +13,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// /// Represents the response for a Managed Identity CSR request. /// - internal class CsrRequestResponse + internal class ClientCredentialRequestResponse { [JsonProperty("client_id")] public string ClientId { get; } @@ -33,16 +33,16 @@ internal class CsrRequestResponse [JsonProperty("refresh_in")] public int RefreshIn { get; } - public CsrRequestResponse() { } + public ClientCredentialRequestResponse() { } - public static bool ValidateCsrRequestResponse(CsrRequestResponse csrRequestResponse) + public static bool ValidateCsrRequestResponse(ClientCredentialRequestResponse clientCredentialRequestResponse) { - if (string.IsNullOrEmpty(csrRequestResponse.ClientId) || - string.IsNullOrEmpty(csrRequestResponse.TenantId) || - string.IsNullOrEmpty(csrRequestResponse.ClientCredential) || - string.IsNullOrEmpty(csrRequestResponse.RegionalTokenUrl) || - csrRequestResponse.ExpiresIn <= 0 || - csrRequestResponse.RefreshIn <= 0) + if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || + string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || + string.IsNullOrEmpty(clientCredentialRequestResponse.ClientCredential) || + string.IsNullOrEmpty(clientCredentialRequestResponse.RegionalTokenUrl) || + clientCredentialRequestResponse.ExpiresIn <= 0 || + clientCredentialRequestResponse.RefreshIn <= 0) { return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs index 2604a8c31e..c3b05ec34e 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs @@ -8,11 +8,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity { - internal class CsrRequest + internal class Csr { public string Pem { get; } - public CsrRequest(string pem) + public Csr(string pem) { Pem = pem ?? throw new ArgumentNullException(nameof(pem)); } @@ -24,19 +24,19 @@ public CsrRequest(string pem) /// AAD tenant_id. /// CuidInfo object containing required VMID and optional VMSSID. /// CsrRequest containing the PEM CSR. - public static CsrRequest Generate(string clientId, string tenantId, CuidInfo cuid) + public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) { - if (string.IsNullOrWhiteSpace(clientId)) + if (string.IsNullOrEmpty(clientId)) throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); - if (string.IsNullOrWhiteSpace(tenantId)) + if (string.IsNullOrEmpty(tenantId)) throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); if (cuid == null) throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrWhiteSpace(cuid.Vmid)) + if (string.IsNullOrEmpty(cuid.Vmid)) throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); - return new CsrRequest(pemCsr); + return new Csr(pemCsr); } /// diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 08bbc9caf4..6787b9dfc9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -17,7 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; - private const string CsrRequestPath = "/metadata/identity/issuecredential"; + private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -195,7 +195,7 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } - private async Task ExecuteCsrRequestAsync( + private async Task ExecuteClientCredentialRequestAsync( RequestContext requestContext, string queryParams, string pem) @@ -214,7 +214,7 @@ private async Task ExecuteCsrRequestAsync( try { response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrRequestPath, queryParams), + ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, @@ -236,8 +236,8 @@ private async Task ExecuteCsrRequestAsync( (int)response.StatusCode); } - var csrRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!CsrRequestResponse.ValidateCsrRequestResponse(csrRequestResponse)) + var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!ClientCredentialRequestResponse.ValidateCsrRequestResponse(clientCredentialRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, @@ -247,13 +247,13 @@ private async Task ExecuteCsrRequestAsync( (int)response.StatusCode); } - return csrRequestResponse; + return clientCredentialRequestResponse; } protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csrRequest = CsrRequest.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); var queryParams = $"cid={csrMetadata.Cuid}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) @@ -262,7 +262,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) } queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - var csrRequestResponse = ExecuteCsrRequestAsync(_requestContext, queryParams, csrRequest.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(_requestContext, queryParams, csr.Pem); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 578bb27e45..6c52e6dded 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -82,6 +82,7 @@ + @@ -163,4 +164,8 @@ + + + + diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 2e09ace05e..6851b425e3 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -142,10 +142,10 @@ public void TestCsrGeneration() }; // Generate CSR - var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csr = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [DataTestMethod] @@ -162,7 +162,7 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId }; Assert.ThrowsException(() => - CsrRequest.Generate(clientId, tenantId, cuid)); + Csr.Generate(clientId, tenantId, cuid)); } [TestMethod] @@ -170,7 +170,7 @@ public void TestCsrGeneration_NullCuid() { // Test with null CUID Assert.ThrowsException(() => - CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); } [DataTestMethod] @@ -186,7 +186,7 @@ public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) // Should throw ArgumentException since Vmid is required Assert.ThrowsException(() => - CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] @@ -201,7 +201,7 @@ public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) }; // Should succeed since Vmssid is optional (Vmid is provided and valid) - var csrRequest = CsrRequest.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); From 480ae9ea4174cfa118ef1489a548cefe5f302c2f Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 7 Aug 2025 14:16:57 -0400 Subject: [PATCH 08/41] renamed file --- .../ManagedIdentity/{CsrRequest.cs => Csr.cs} | 0 .../Microsoft.Identity.Client/Microsoft.Identity.Client.csproj | 2 ++ 2 files changed, 2 insertions(+) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{CsrRequest.cs => Csr.cs} (100%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs similarity index 100% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrRequest.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 6c52e6dded..468292d402 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -83,6 +83,7 @@ + @@ -167,5 +168,6 @@ + From 0aa869281000cd7b933fae3781a0d120168e61a6 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 16:36:12 -0400 Subject: [PATCH 09/41] small improvement --- .../ImdsV2ManagedIdentitySource.cs | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 6787b9dfc9..fe49b73266 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -196,33 +196,39 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } private async Task ExecuteClientCredentialRequestAsync( - RequestContext requestContext, - string queryParams, + CuidInfo Cuid, string pem) { + var queryParams = $"cid={Cuid}"; + if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) + { + queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + } + queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; + var headers = new Dictionary { { "Metadata", "true" }, - { "x-ms-client-request-id", requestContext.CorrelationId.ToString() } + { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); HttpResponse response = null; try { - response = await requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, ClientCredentialRequestPath, queryParams), + response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, - logger: requestContext.Logger, + logger: _requestContext.Logger, doNotThrow: false, mtlsCertificate: null, validateServerCertificate: null, - cancellationToken: requestContext.UserCancellationToken, + cancellationToken: _requestContext.UserCancellationToken, retryPolicy: retryPolicy) .ConfigureAwait(false); } @@ -255,14 +261,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var queryParams = $"cid={csrMetadata.Cuid}"; - if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) - { - queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; - } - queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(_requestContext, queryParams, csr.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem); throw new NotImplementedException(); } From de24670c9c16c67e7af6d89d319418869d0e3dd9 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 17:38:17 -0400 Subject: [PATCH 10/41] Initial implementation --- .../ImdsV2ManagedIdentitySource.cs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index fe49b73266..3d2a101b66 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,6 +18,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -261,9 +262,22 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); - throw new NotImplementedException(); + ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{clientCredentialRequestResponse.RegionalTokenUrl}/{clientCredentialRequestResponse.TenantId}{AcquireEntraTokenPath}")); + + request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); + + request.BodyParameters.Add("grant_type", clientCredentialRequestResponse.ClientCredential); + request.BodyParameters.Add("scope", "https://management.azure.com/.default"); + if (clientCredentialRequestResponse.ClientId != null) + { + request.BodyParameters.Add("client_id", clientCredentialRequestResponse.ClientId); + } + + request.RequestType = RequestType.Imds; + + return request; } } } From 621c5662b3b5b26b8e1189bc8e6f8ed62798a3e4 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 17:38:56 -0400 Subject: [PATCH 11/41] added missing awaitor for async method --- global.json | 2 +- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/global.json b/global.json index 66e4a5c8a7..e5135e9ff3 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "8.0.404", + "version": "9.0.0", "rollForward": "latestFeature" } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index fe49b73266..c2151a3b66 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -261,7 +261,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem); + var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } From 068461b344757548f01a734ce8fa88f1e4cabefb Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:17:19 -0400 Subject: [PATCH 12/41] Fixed bugs discovered from unit testing in child branch --- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index c2151a3b66..081e353f48 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,6 +18,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -199,7 +200,7 @@ private async Task ExecuteClientCredentialReque CuidInfo Cuid, string pem) { - var queryParams = $"cid={Cuid}"; + var queryParams = $"cid={JsonHelper.SerializeToJson(Cuid)}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) { queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; @@ -212,6 +213,12 @@ private async Task ExecuteClientCredentialReque { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; + var payload = new + { + pem = pem + }; + var body = JsonHelper.SerializeToJson(payload); + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); @@ -222,7 +229,7 @@ private async Task ExecuteClientCredentialReque response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), headers, - body: new StringContent($"{{\"pem\":\"{pem}\"}}", System.Text.Encoding.UTF8, "application/json"), + body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, logger: _requestContext.Logger, doNotThrow: false, From 2034b25487a33eb8b42baae6547d900e7064ecef Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:20:36 -0400 Subject: [PATCH 13/41] undid changes to .proj --- .../Microsoft.Identity.Client.csproj | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 468292d402..578bb27e45 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -82,8 +82,6 @@ - - @@ -165,9 +163,4 @@ - - - - - From 2b7486a5d5060e93405bf619cab32c4356bcb90e Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:26:09 -0400 Subject: [PATCH 14/41] undid change to global.json --- global.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/global.json b/global.json index e5135e9ff3..66e4a5c8a7 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "9.0.0", + "version": "8.0.404", "rollForward": "latestFeature" } } From 310c467a41d2a757a0e186096816925208ccdc9e Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:29:33 -0400 Subject: [PATCH 15/41] started unit testing --- .../ClientCredentialRequestResponse.cs | 12 ++--- .../Core/Mocks/MockHelpers.cs | 34 +++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 51 +++++++++++++++++++ 3 files changed, 91 insertions(+), 6 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index 924dd75ad1..efec6a1487 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -16,22 +16,22 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ClientCredentialRequestResponse { [JsonProperty("client_id")] - public string ClientId { get; } + public string ClientId { get; set; } [JsonProperty("tenant_id")] - public string TenantId { get; } + public string TenantId { get; set; } [JsonProperty("client_credential")] - public string ClientCredential { get; } + public string ClientCredential { get; set; } [JsonProperty("regional_token_url")] - public string RegionalTokenUrl { get; } + public string RegionalTokenUrl { get; set; } [JsonProperty("expires_in")] - public int ExpiresIn { get; } + public int ExpiresIn { get; set; } [JsonProperty("refresh_in")] - public int RefreshIn { get; } + public int RefreshIn { get; set; } public ClientCredentialRequestResponse() { } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index c0c293840a..93af63edd1 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -625,5 +625,39 @@ public static MockHttpMessageHandler MockCsrResponseFailure() // 400 doesn't trigger the retry policy return MockCsrResponse(HttpStatusCode.BadRequest); } + + public static MockHttpMessageHandler MockClientCredentialResponse() + { + IDictionary expectedQueryParams = new Dictionary(); + IDictionary expectedRequestHeaders = new Dictionary(); + expectedQueryParams.Add("cid", "%7B%22vmid%22:%22fake_vmid%22,%22vmssid%22:%22fake_vmssid%22%7D"); + //expectedQueryParams.Add("uaid", "fake_client_id"); + expectedQueryParams.Add("api-version", "2018-02-01"); + expectedRequestHeaders.Add("Metadata", "true"); + + string content = + "{" + + "\"client_id\": \"fake_client_id\"," + + "\"tenant_id\": \"fake_tenant_id\"," + + "\"client_credential\": \"fake_client_credential\"," + + "\"regional_token_url\": \"fake_regional_token_url\"," + + "\"expires_in\": 3600," + + "\"refresh_in\": 1800" + + "}"; + + var handler = new MockHttpMessageHandler() + { + ExpectedUrl = "http://169.254.169.254/metadata/identity/issuecredential", + ExpectedMethod = HttpMethod.Post, + ExpectedQueryParams = expectedQueryParams, + ExpectedRequestHeaders = expectedRequestHeaders, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content), + } + }; + + return handler; + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 6851b425e3..aa4bc6e171 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; @@ -18,6 +19,56 @@ public class ImdsV2Tests : TestBase { private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + [TestMethod] + public async Task ImdsV2HappyPathAsync() + { + using (var httpManager = new MockHttpManager()) + { + /*ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager);*/ + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockClientCredentialResponse()); + + // TODO: finish this. everything has been tested to this point. + /*MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler( + "MachineLearningEndpoint", + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.ImdsV2//, + // userAssignedId: userAssignedId, + // userAssignedIdentityId);*/ + + // this will fail, see TODO above + var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // this will fail, see TODO above + result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + [TestMethod] public async Task GetCsrMetadataAsyncSucceeds() { From 189ff9e9f79db3b0177af10ae3d337aad2f28d8d Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:32:26 -0400 Subject: [PATCH 16/41] added missing sets --- .../ClientCredentialRequestResponse.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index 924dd75ad1..efec6a1487 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -16,22 +16,22 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ClientCredentialRequestResponse { [JsonProperty("client_id")] - public string ClientId { get; } + public string ClientId { get; set; } [JsonProperty("tenant_id")] - public string TenantId { get; } + public string TenantId { get; set; } [JsonProperty("client_credential")] - public string ClientCredential { get; } + public string ClientCredential { get; set; } [JsonProperty("regional_token_url")] - public string RegionalTokenUrl { get; } + public string RegionalTokenUrl { get; set; } [JsonProperty("expires_in")] - public int ExpiresIn { get; } + public int ExpiresIn { get; set; } [JsonProperty("refresh_in")] - public int RefreshIn { get; } + public int RefreshIn { get; set; } public ClientCredentialRequestResponse() { } From 9345e990456a19903695e32953d83b8dd9c69415 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:33:26 -0400 Subject: [PATCH 17/41] merged from parent branch --- global.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/global.json b/global.json index 66e4a5c8a7..e5135e9ff3 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "8.0.404", + "version": "9.0.0", "rollForward": "latestFeature" } } From c72e61b3e8e93a079f05830a6dfd0401ab62ec76 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 8 Aug 2025 19:34:32 -0400 Subject: [PATCH 18/41] undid changes to global.json --- global.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/global.json b/global.json index e5135e9ff3..66e4a5c8a7 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "9.0.0", + "version": "8.0.404", "rollForward": "latestFeature" } } From 92b325fd897a86a9ee674cd443e7dda56b56d9c9 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 11 Aug 2025 15:40:39 -0400 Subject: [PATCH 19/41] Inplemented some feedback --- .../ManagedIdentity/ClientCredentialRequestResponse.cs | 2 +- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs index efec6a1487..22d92566af 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs @@ -35,7 +35,7 @@ internal class ClientCredentialRequestResponse public ClientCredentialRequestResponse() { } - public static bool ValidateCsrRequestResponse(ClientCredentialRequestResponse clientCredentialRequestResponse) + public static bool IsValid(ClientCredentialRequestResponse clientCredentialRequestResponse) { if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 081e353f48..1e54224700 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,7 +18,6 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; - private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -250,7 +249,7 @@ private async Task ExecuteClientCredentialReque } var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!ClientCredentialRequestResponse.ValidateCsrRequestResponse(clientCredentialRequestResponse)) + if (!ClientCredentialRequestResponse.IsValid(clientCredentialRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, From e85fc9a6cb46df3bf5ad14ff89b22e926e6f5de5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Mon, 11 Aug 2025 15:41:42 -0400 Subject: [PATCH 20/41] merged from parent --- .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index a18cf6b087..8297a905a7 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -18,6 +18,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, From 067c83c58efde7ae49a00dc986607ea5212f97a5 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 14 Aug 2025 14:30:01 -0400 Subject: [PATCH 21/41] Implemented some feedback --- ...ponse.cs => CertificateRequestResponse.cs} | 21 +++++++++---------- .../ImdsV2ManagedIdentitySource.cs | 14 ++++++------- 2 files changed, 17 insertions(+), 18 deletions(-) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ClientCredentialRequestResponse.cs => CertificateRequestResponse.cs} (57%) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs similarity index 57% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs index 22d92566af..4391fba4be 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ClientCredentialRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs @@ -4,8 +4,7 @@ #if SUPPORTS_SYSTEM_TEXT_JSON using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; #else -using Microsoft.Identity.Client.Utils; -using Microsoft.Identity.Json; + using Microsoft.Identity.Json; #endif namespace Microsoft.Identity.Client.ManagedIdentity @@ -13,7 +12,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// /// Represents the response for a Managed Identity CSR request. /// - internal class ClientCredentialRequestResponse + internal class CertificateRequestResponse { [JsonProperty("client_id")] public string ClientId { get; set; } @@ -33,16 +32,16 @@ internal class ClientCredentialRequestResponse [JsonProperty("refresh_in")] public int RefreshIn { get; set; } - public ClientCredentialRequestResponse() { } + public CertificateRequestResponse() { } - public static bool IsValid(ClientCredentialRequestResponse clientCredentialRequestResponse) + public static bool IsValid(CertificateRequestResponse certificateRequestResponse) { - if (string.IsNullOrEmpty(clientCredentialRequestResponse.ClientId) || - string.IsNullOrEmpty(clientCredentialRequestResponse.TenantId) || - string.IsNullOrEmpty(clientCredentialRequestResponse.ClientCredential) || - string.IsNullOrEmpty(clientCredentialRequestResponse.RegionalTokenUrl) || - clientCredentialRequestResponse.ExpiresIn <= 0 || - clientCredentialRequestResponse.RefreshIn <= 0) + if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || + string.IsNullOrEmpty(certificateRequestResponse.TenantId) || + string.IsNullOrEmpty(certificateRequestResponse.ClientCredential) || + string.IsNullOrEmpty(certificateRequestResponse.RegionalTokenUrl) || + certificateRequestResponse.ExpiresIn <= 0 || + certificateRequestResponse.RefreshIn <= 0) { return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 1e54224700..08ec8313cf 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -17,7 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; - private const string ClientCredentialRequestPath = "/metadata/identity/issuecredential"; + private const string CertificateRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -195,7 +195,7 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } - private async Task ExecuteClientCredentialRequestAsync( + private async Task ExecuteCertificateRequestAsync( CuidInfo Cuid, string pem) { @@ -226,7 +226,7 @@ private async Task ExecuteClientCredentialReque try { response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( - ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, ClientCredentialRequestPath, queryParams), + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, CertificateRequestPath, queryParams), headers, body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), method: HttpMethod.Post, @@ -248,8 +248,8 @@ private async Task ExecuteClientCredentialReque (int)response.StatusCode); } - var clientCredentialRequestResponse = JsonHelper.DeserializeFromJson(response.Body); - if (!ClientCredentialRequestResponse.IsValid(clientCredentialRequestResponse)) + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, @@ -259,7 +259,7 @@ private async Task ExecuteClientCredentialReque (int)response.StatusCode); } - return clientCredentialRequestResponse; + return certificateRequestResponse; } protected override ManagedIdentityRequest CreateRequest(string resource) @@ -267,7 +267,7 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); - var clientCredentialRequestResponse = ExecuteClientCredentialRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } From f7d6f881386099c1f776e6362783fb6f76e0c4fe Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 13:54:41 -0400 Subject: [PATCH 22/41] PKCS1 -> Pss padding --- .../ManagedIdentity/Csr.cs | 78 ++++++++++++++----- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index c3b05ec34e..5599231411 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -67,18 +67,29 @@ private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidIn } /// - /// Creates a 2048-bit RSA key pair compatible with all target frameworks. + /// Creates a 2048-bit RSA key pair that supports PSS padding across all target frameworks. /// + /// + /// On .NET Framework 4.6.2/4.7.2 (Windows-only), explicitly uses RSACng (also Windows-only) + /// to ensure PSS padding support, as RSA.Create() may return RSACryptoServiceProvider + /// which doesn't support PSS. + /// On .NET Standard 2.0 and .NET 8.0+ (cross-platform), uses RSA.Create() which returns + /// modern implementations that support PSS: RSACng on Windows, OpenSSL-based on Linux/macOS. + /// + /// An RSA instance configured for 2048-bit keys with PSS padding capability. private static RSA CreateRsaKeyPair() { + RSA rsa = null; + #if NET462 || NET472 - var rsa = new RSACryptoServiceProvider(2048); - return rsa; + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new System.Security.Cryptography.RSACng(); #else - var rsa = RSA.Create(); + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif rsa.KeySize = 2048; return rsa; -#endif } /// @@ -167,31 +178,56 @@ private static byte[] BuildAttributes(CuidInfo cuid) } /// - /// Builds the signature algorithm identifier for SHA256withRSA. + /// Builds the signature algorithm identifier for RSASSA-PSS with SHA256. /// private static byte[] BuildSignatureAlgorithmIdentifier() { - byte[] sha256WithRsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 11 }); // SHA256withRSA OID - byte[] nullParam = EncodeAsn1Null(); - return EncodeAsn1Sequence(new[] { sha256WithRsaOid, nullParam }); + byte[] rsassaPssOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 10 }); // RSASSA-PSS OID + byte[] pssParams = BuildPssParameters(); + return EncodeAsn1Sequence(new[] { rsassaPssOid, pssParams }); } /// - /// Signs the CertificationRequestInfo with SHA256withRSA. + /// Builds the RSASSA-PSS parameters for SHA256 with MGF1. + /// + private static byte[] BuildPssParameters() + { + var parameters = new System.Collections.Generic.List(); + + // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 + // We explicitly specify SHA256 since default is SHA1 + byte[] sha256Oid = EncodeAsn1ObjectIdentifier(new int[] { 2, 16, 840, 1, 101, 3, 4, 2, 1 }); // SHA256 OID + byte[] sha256Null = EncodeAsn1Null(); + byte[] hashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); + byte[] hashAlgorithmParam = EncodeAsn1ContextSpecific(0, hashAlgorithm); + parameters.Add(hashAlgorithmParam); + + // maskGenAlgorithm [1] AlgorithmIdentifier DEFAULT mgf1SHA1 + // We explicitly specify MGF1 with SHA256 + byte[] mgf1Oid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 8 }); // MGF1 OID + byte[] mgf1HashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); // MGF1 uses SHA256 + byte[] maskGenAlgorithm = EncodeAsn1Sequence(new[] { mgf1Oid, mgf1HashAlgorithm }); + byte[] maskGenAlgorithmParam = EncodeAsn1ContextSpecific(1, maskGenAlgorithm); + parameters.Add(maskGenAlgorithmParam); + + // saltLength [2] INTEGER DEFAULT 20 + // We explicitly specify 32 for SHA256 (hash length) + byte[] saltLength = EncodeAsn1Integer(32); + byte[] saltLengthParam = EncodeAsn1ContextSpecific(2, saltLength); + parameters.Add(saltLengthParam); + + // trailerField [3] INTEGER DEFAULT 1 + // Default value is 1 (0xBC), so we omit this parameter + + return EncodeAsn1Sequence(parameters.ToArray()); + } + + /// + /// Signs the CertificationRequestInfo with SHA256withRSA-PSS. /// private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) { -#if NET462 || NET472 - using (var sha256 = SHA256.Create()) - { - byte[] hash = sha256.ComputeHash(certificationRequestInfo); - var formatter = new RSAPKCS1SignatureFormatter(rsa); - formatter.SetHashAlgorithm("SHA256"); - return formatter.CreateSignature(hash); - } -#else - return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); -#endif + return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); } /// From 74e8e606d014439c750156b65fd628d07e2e1ae7 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 15:13:04 -0400 Subject: [PATCH 23/41] re-used imports --- .../ManagedIdentity/Csr.cs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index 5599231411..c35a93e4a9 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Security.Cryptography; using System.Text; using Microsoft.Identity.Client.Utils; @@ -83,7 +84,7 @@ private static RSA CreateRsaKeyPair() #if NET462 || NET472 // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available - rsa = new System.Security.Cryptography.RSACng(); + rsa = new RSACng(); #else // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation rsa = RSA.Create(); @@ -97,7 +98,7 @@ private static RSA CreateRsaKeyPair() /// private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) { - var components = new System.Collections.Generic.List(); + var components = new List(); // Version (INTEGER 0) components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); @@ -119,7 +120,7 @@ private static byte[] BuildCertificationRequestInfo(string clientId, string tena /// private static byte[] BuildSubjectName(string clientId, string tenantId) { - var rdnSequence = new System.Collections.Generic.List(); + var rdnSequence = new List(); // CN= byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID @@ -163,7 +164,7 @@ private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) /// private static byte[] BuildAttributes(CuidInfo cuid) { - var attributes = new System.Collections.Generic.List(); + var attributes = new List(); // CUID attribute (OID 1.2.840.113549.1.9.7) // Serialize CuidInfo as JSON object string using existing JSON serialization @@ -192,7 +193,7 @@ private static byte[] BuildSignatureAlgorithmIdentifier() /// private static byte[] BuildPssParameters() { - var parameters = new System.Collections.Generic.List(); + var parameters = new List(); // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 // We explicitly specify SHA256 since default is SHA1 @@ -309,7 +310,7 @@ private static byte[] EncodeAsn1Integer(int value) if (value == 0) return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); int temp = value; while (temp > 0) { @@ -365,7 +366,7 @@ private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) if (oid == null || oid.Length < 2) throw new ArgumentException("OID must have at least 2 components"); - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); // First two components are encoded as (first * 40 + second) bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); @@ -411,7 +412,7 @@ private static byte[] EncodeAsn1Length(int length) return new byte[] { (byte)length }; } - var lengthBytes = new System.Collections.Generic.List(); + var lengthBytes = new List(); int temp = length; while (temp > 0) { @@ -433,7 +434,7 @@ private static byte[] EncodeOidComponent(int value) if (value == 0) return new byte[] { 0x00 }; - var bytes = new System.Collections.Generic.List(); + var bytes = new List(); int temp = value; bytes.Insert(0, (byte)(temp & 0x7F)); From 152f396704046f77d37e8fca55b7c82993156504 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 15 Aug 2025 15:51:12 -0400 Subject: [PATCH 24/41] Implemented feedback --- .../ManagedIdentity/Csr.cs | 1 + .../ManagedIdentity/ImdsV2ManagedIdentitySource.cs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index c35a93e4a9..fdc8584cd3 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -80,6 +80,7 @@ private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidIn /// An RSA instance configured for 2048-bit keys with PSS padding capability. private static RSA CreateRsaKeyPair() { + // TODO: use the strongest key on the machine i.e. a TPM key RSA rsa = null; #if NET462 || NET472 diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 08ec8313cf..4d5354dcc6 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -242,18 +242,28 @@ private async Task ExecuteCertificateRequestAsync( { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCsrRequest failed.", + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", ex, ManagedIdentitySource.ImdsV2, (int)response.StatusCode); } + if (response.StatusCode != HttpStatusCode.OK) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) { throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - $"[ImdsV2] ImdsV2ManagedIdentitySource.GetCsrMetadataAsync failed because the CsrMetadata response is invalid. Status code: {response.StatusCode} Body: {response.Body}", + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: {response.StatusCode}", null, ManagedIdentitySource.ImdsV2, (int)response.StatusCode); From d46c853ba31e2d903d32aca5026ff566c41573f2 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Tue, 19 Aug 2025 14:36:28 -0400 Subject: [PATCH 25/41] Changes from manual testing. --- .../ManagedIdentity/Csr.cs | 11 +++-- .../ManagedIdentity/CsrMetadata.cs | 18 ++++---- .../ImdsV2ManagedIdentitySource.cs | 31 ++++++------- .../net/MsalJsonSerializerContext.cs | 5 ++ .../Core/Mocks/MockHelpers.cs | 6 +-- .../TestConstants.cs | 4 +- .../ManagedIdentityTests/CsrValidator.cs | 2 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 46 ++++++++++--------- 8 files changed, 68 insertions(+), 55 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs index fdc8584cd3..899304b544 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -9,6 +9,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity { + internal class PemPayload + { + public string pem { get; set; } + } + internal class Csr { public string Pem { get; } @@ -23,7 +28,7 @@ public Csr(string pem) /// /// Managed Identity client_id. /// AAD tenant_id. - /// CuidInfo object containing required VMID and optional VMSSID. + /// CuidInfo object containing required vmId and optional vmssId. /// CsrRequest containing the PEM CSR. public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) { @@ -33,8 +38,8 @@ public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); if (cuid == null) throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrEmpty(cuid.Vmid)) - throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); + if (string.IsNullOrEmpty(cuid.VmId)) + throw new ArgumentException("cuid.VmId must not be null or empty.", nameof(cuid.VmId)); string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); return new Csr(pemCsr); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs index a831d02c7a..5de9a5e490 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs @@ -14,11 +14,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity /// internal class CuidInfo { - [JsonProperty("vmid")] - public string Vmid { get; set; } + [JsonProperty("vmId")] + public string VmId { get; set; } - [JsonProperty("vmssid")] - public string Vmssid { get; set; } + [JsonProperty("vmssId")] + public string VmssId { get; set; } } /// @@ -29,8 +29,8 @@ internal class CsrMetadata /// /// VM unique Id /// - [JsonProperty("cuid")] - public CuidInfo Cuid { get; set; } + [JsonProperty("cuId")] + public CuidInfo CuId { get; set; } /// /// client_id of the Managed Identity @@ -57,12 +57,12 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any required field is null. Note: Vmid is required, Vmssid is optional. + /// false if any required field is null. Note: VmId is required, VmssId is optional. public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || - csrMetadata.Cuid == null || - string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) || + csrMetadata.CuId == null || + string.IsNullOrEmpty(csrMetadata.CuId.VmId) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 4d5354dcc6..bd0d2c2e4a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -16,15 +16,16 @@ namespace Microsoft.Identity.Client.ManagedIdentity { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { - private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; + public const string ImdsV2ApiVersion = "2.0"; + private const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; private const string CertificateRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, bool probeMode) { - string queryParams = $"cred-api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; - + string queryParams = $"cred-api-version={ImdsV2ApiVersion}"; + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam( requestContext.ServiceBundle.Config.ManagedIdentityId.IdType, requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId, @@ -129,12 +130,14 @@ private static bool ValidateCsrMetadataResponse( * "1556" // index 1: captured group (\d+) * ] */ - string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; if (serverHeader == null) { if (probeMode) { - logger.Info(() => "[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response."); + logger.Info(() => $"[Managed Identity] IMDSv2 managed identity is not available. 'server' header is missing from the CSR metadata response. Body: {response.Body}"); return false; } else @@ -164,7 +167,7 @@ private static bool ValidateCsrMetadataResponse( null, (int)response.StatusCode); } - } + }*/ return true; } @@ -196,26 +199,22 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } private async Task ExecuteCertificateRequestAsync( - CuidInfo Cuid, + CuidInfo cuid, string pem) { - var queryParams = $"cid={JsonHelper.SerializeToJson(Cuid)}"; + var queryParams = $"cuid={JsonHelper.SerializeToJson(cuid)}&cred-api-version={ImdsV2ApiVersion}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) { queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; } - queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; var headers = new Dictionary { { "Metadata", "true" }, { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - - var payload = new - { - pem = pem - }; + + var payload = new PemPayload { pem = pem }; var body = JsonHelper.SerializeToJson(payload); IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; @@ -275,9 +274,9 @@ private async Task ExecuteCertificateRequestAsync( protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csr.Pem).GetAwaiter().GetResult(); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs index d36f036282..6d6a6cb7f2 100644 --- a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs +++ b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs @@ -40,6 +40,10 @@ namespace Microsoft.Identity.Client.Platforms.net [JsonSerializable(typeof(ManagedIdentityResponse))] [JsonSerializable(typeof(ManagedIdentityErrorResponse))] [JsonSerializable(typeof(OidcMetadata))] + [JsonSerializable(typeof(CsrMetadata))] + [JsonSerializable(typeof(CuidInfo))] + [JsonSerializable(typeof(CertificateRequestResponse))] + [JsonSerializable(typeof(PemPayload))] [JsonSourceGenerationOptions] internal partial class MsalJsonSerializerContext : JsonSerializerContext { @@ -54,6 +58,7 @@ public static MsalJsonSerializerContext Custom { NumberHandling = JsonNumberHandling.AllowReadingFromString, AllowTrailingCommas = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new JsonStringConverter(), diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index c0c293840a..3cf5e7a64f 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -590,12 +590,12 @@ public static MockHttpMessageHandler MockCsrResponse( { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - expectedQueryParams.Add("cred-api-version", "2018-02-01"); + expectedQueryParams.Add("cred-api-version", "2.0"); expectedRequestHeaders.Add("Metadata", "true"); string content = "{" + - "\"cuid\": { \"vmid\": \"fake_vmid\", \"vmssid\": \"fake_vmssid\" }," + + "\"cuid\": { \"vmId\": \"fake_vmId\", \"vmssId\": \"fake_vmssId\" }," + "\"clientId\": \"fake_client_id\"," + "\"tenantId\": \"fake_tenant_id\"," + "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + @@ -603,7 +603,7 @@ public static MockHttpMessageHandler MockCsrResponse( var handler = new MockHttpMessageHandler() { - ExpectedUrl = "http://169.254.169.254/metadata/identity/getPlatformMetadata", + ExpectedUrl = "http://169.254.169.254/metadata/identity/getplatformmetadata", ExpectedMethod = HttpMethod.Get, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index d4a63354c0..5a2ea2986a 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -154,8 +154,8 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; - public const string Vmid = "test-vm-id"; - public const string Vmssid = "test-vmss-id"; + public const string VmId = "test-vm-id"; + public const string VmssId = "test-vmss-id"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index 671700c100..adbc2e298d 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -302,7 +302,7 @@ private static void ValidatePublicKey(byte[] publicKeyBytes) /// /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. - /// Note: Vmid is required, Vmssid is optional and will be omitted if null/empty. + /// Note: VmId is required, VmssId is optional and will be omitted if null/empty. /// private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 6851b425e3..312d43ec74 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -60,7 +60,9 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() } } - [TestMethod] + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*[TestMethod] public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { using (var httpManager = new MockHttpManager()) @@ -75,9 +77,11 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - } + }*/ - [TestMethod] + // Imds bug: headers are missing + // TODO: uncomment this when the bug is fixed + /*[TestMethod] public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() { using (var httpManager = new MockHttpManager()) @@ -92,7 +96,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - } + }*/ [TestMethod] public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() @@ -137,8 +141,8 @@ public void TestCsrGeneration() { var cuid = new CuidInfo { - Vmid = TestConstants.Vmid, - Vmssid = TestConstants.Vmssid + VmId = TestConstants.VmId, + VmssId = TestConstants.VmssId }; // Generate CSR @@ -157,8 +161,8 @@ public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId { var cuid = new CuidInfo { - Vmid = TestConstants.Vmid, - Vmssid = TestConstants.Vmssid + VmId = TestConstants.VmId, + //VmssId = TestConstants.VmssId }; Assert.ThrowsException(() => @@ -174,38 +178,38 @@ public void TestCsrGeneration_NullCuid() } [DataTestMethod] - [DataRow(null, TestConstants.Vmssid)] - [DataRow("", TestConstants.Vmssid)] - public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) + [DataRow(null, TestConstants.VmssId)] + [DataRow("", TestConstants.VmssId)] + public void TestCsrGeneration_InvalidVmId(string vmId, string vmssId) { var cuid = new CuidInfo { - Vmid = vmid, - Vmssid = vmssid + VmId = vmId, + //VmssId = vmssId }; - // Should throw ArgumentException since Vmid is required + // Should throw ArgumentException since VmId is required Assert.ThrowsException(() => Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); } [DataTestMethod] - [DataRow(TestConstants.Vmid, null)] - [DataRow(TestConstants.Vmid, "")] - public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) + [DataRow(TestConstants.VmId, null)] + [DataRow(TestConstants.VmId, "")] + public void TestCsrGeneration_OptionalVmssId(string vmId, string vmssId) { var cuid = new CuidInfo { - Vmid = vmid, - Vmssid = vmssid + VmId = vmId, + //VmssId = vmssId }; - // Should succeed since Vmssid is optional (Vmid is provided and valid) + // Should succeed since VmssId is optional (VmId is provided and valid) var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); Assert.IsNotNull(csrRequest); Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); - // Validate the CSR contents - this should handle null/empty VMSSID gracefully + // Validate the CSR contents - this should handle null/empty vmssId gracefully CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); } From 3f75e3ad3c767246306c68aa66e89f4b6d2dc878 Mon Sep 17 00:00:00 2001 From: Robbie-Microsoft <87724641+Robbie-Microsoft@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:27:25 -0400 Subject: [PATCH 26/41] ImdsV2: Reworked Custom ASN1 Encoder to use System.Formats.Asn1 Nuget Package (#5449) --- Directory.Packages.props | 2 +- .../ManagedIdentity/Csr.cs | 482 ------------------ .../ManagedIdentity/ManagedIdentityClient.cs | 1 + .../ManagedIdentity/V2/CertificateRequest.cs | 235 +++++++++ .../{ => V2}/CertificateRequestResponse.cs | 2 +- .../ManagedIdentity/V2/Csr.cs | 51 ++ .../ManagedIdentity/{ => V2}/CsrMetadata.cs | 2 +- .../{ => V2}/ImdsV2ManagedIdentitySource.cs | 11 +- .../Microsoft.Identity.Client.csproj | 2 + .../net/MsalJsonSerializerContext.cs | 2 +- .../ManagedIdentityTests/CsrValidator.cs | 457 +++-------------- .../ManagedIdentityTests/ImdsV2Tests.cs | 70 +-- 12 files changed, 388 insertions(+), 929 deletions(-) delete mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/CertificateRequestResponse.cs (96%) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/CsrMetadata.cs (97%) rename src/client/Microsoft.Identity.Client/ManagedIdentity/{ => V2}/ImdsV2ManagedIdentitySource.cs (97%) diff --git a/Directory.Packages.props b/Directory.Packages.props index f222636fe1..d8f368dfc9 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -17,6 +17,7 @@ + @@ -80,6 +81,5 @@ - diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs deleted file mode 100644 index 899304b544..0000000000 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Security.Cryptography; -using System.Text; -using Microsoft.Identity.Client.Utils; - -namespace Microsoft.Identity.Client.ManagedIdentity -{ - internal class PemPayload - { - public string pem { get; set; } - } - - internal class Csr - { - public string Pem { get; } - - public Csr(string pem) - { - Pem = pem ?? throw new ArgumentNullException(nameof(pem)); - } - - /// - /// Generates a CSR for the given client, tenant, and CUID info. - /// - /// Managed Identity client_id. - /// AAD tenant_id. - /// CuidInfo object containing required vmId and optional vmssId. - /// CsrRequest containing the PEM CSR. - public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) - { - if (string.IsNullOrEmpty(clientId)) - throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); - if (string.IsNullOrEmpty(tenantId)) - throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); - if (cuid == null) - throw new ArgumentNullException(nameof(cuid)); - if (string.IsNullOrEmpty(cuid.VmId)) - throw new ArgumentException("cuid.VmId must not be null or empty.", nameof(cuid.VmId)); - - string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); - return new Csr(pemCsr); - } - - /// - /// Generates a PKCS#10 Certificate Signing Request in PEM format. - /// - private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidInfo cuid) - { - // Generate RSA key pair (2048-bit) - RSA rsa = CreateRsaKeyPair(); - - try - { - // Build the CSR components - byte[] certificationRequestInfo = BuildCertificationRequestInfo(clientId, tenantId, cuid, rsa); - byte[] signatureAlgorithm = BuildSignatureAlgorithmIdentifier(); - byte[] signature = SignCertificationRequestInfo(certificationRequestInfo, rsa); - - // Combine into final CSR structure - byte[] csrBytes = BuildFinalCsr(certificationRequestInfo, signatureAlgorithm, signature); - - // Convert to PEM format - return ConvertToPem(csrBytes); - } - finally - { - rsa?.Dispose(); - } - } - - /// - /// Creates a 2048-bit RSA key pair that supports PSS padding across all target frameworks. - /// - /// - /// On .NET Framework 4.6.2/4.7.2 (Windows-only), explicitly uses RSACng (also Windows-only) - /// to ensure PSS padding support, as RSA.Create() may return RSACryptoServiceProvider - /// which doesn't support PSS. - /// On .NET Standard 2.0 and .NET 8.0+ (cross-platform), uses RSA.Create() which returns - /// modern implementations that support PSS: RSACng on Windows, OpenSSL-based on Linux/macOS. - /// - /// An RSA instance configured for 2048-bit keys with PSS padding capability. - private static RSA CreateRsaKeyPair() - { - // TODO: use the strongest key on the machine i.e. a TPM key - RSA rsa = null; - -#if NET462 || NET472 - // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available - rsa = new RSACng(); -#else - // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation - rsa = RSA.Create(); -#endif - rsa.KeySize = 2048; - return rsa; - } - - /// - /// Builds the CertificationRequestInfo structure containing subject, public key, and attributes. - /// - private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) - { - var components = new List(); - - // Version (INTEGER 0) - components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); - - // Subject: CN=, DC= - components.Add(BuildSubjectName(clientId, tenantId)); - - // SubjectPublicKeyInfo - components.Add(BuildSubjectPublicKeyInfo(rsa)); - - // Attributes (including CUID) - components.Add(BuildAttributes(cuid)); - - return EncodeAsn1Sequence(components.ToArray()); - } - - /// - /// Builds the X.500 Distinguished Name for the subject. - /// - private static byte[] BuildSubjectName(string clientId, string tenantId) - { - var rdnSequence = new List(); - - // CN= - byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID - byte[] cnValue = EncodeAsn1Utf8String(clientId); - byte[] cnAttributeValue = EncodeAsn1Sequence(new[] { cnOid, cnValue }); - rdnSequence.Add(EncodeAsn1Set(new[] { cnAttributeValue })); - - // DC= - byte[] dcOid = EncodeAsn1ObjectIdentifier(new int[] { 0, 9, 2342, 19200300, 100, 1, 25 }); // domainComponent OID - byte[] dcValue = EncodeAsn1Utf8String(tenantId); - byte[] dcAttributeValue = EncodeAsn1Sequence(new[] { dcOid, dcValue }); - rdnSequence.Add(EncodeAsn1Set(new[] { dcAttributeValue })); - - return EncodeAsn1Sequence(rdnSequence.ToArray()); - } - - /// - /// Builds the SubjectPublicKeyInfo structure containing the RSA public key. - /// - private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) - { - RSAParameters rsaParams = rsa.ExportParameters(false); - - // RSA Public Key structure - byte[] modulus = EncodeAsn1Integer(rsaParams.Modulus); - byte[] exponent = EncodeAsn1Integer(rsaParams.Exponent); - byte[] rsaPublicKey = EncodeAsn1Sequence(new[] { modulus, exponent }); - - // Algorithm identifier for RSA encryption - byte[] rsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 1 }); // RSA encryption OID - byte[] nullParam = EncodeAsn1Null(); - byte[] algorithmIdentifier = EncodeAsn1Sequence(new[] { rsaOid, nullParam }); - - // SubjectPublicKeyInfo - byte[] publicKeyBitString = EncodeAsn1BitString(rsaPublicKey); - return EncodeAsn1Sequence(new[] { algorithmIdentifier, publicKeyBitString }); - } - - /// - /// Builds the attributes section including the CUID extension. - /// - private static byte[] BuildAttributes(CuidInfo cuid) - { - var attributes = new List(); - - // CUID attribute (OID 1.2.840.113549.1.9.7) - // Serialize CuidInfo as JSON object string using existing JSON serialization - byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); - string cuidValue = JsonHelper.SerializeToJson(cuid); - byte[] cuidData = EncodeAsn1PrintableString(cuidValue); - byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); - byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); - attributes.Add(cuidAttribute); - - return EncodeAsn1ContextSpecific(0, EncodeAsn1SequenceRaw(attributes.ToArray())); - } - - /// - /// Builds the signature algorithm identifier for RSASSA-PSS with SHA256. - /// - private static byte[] BuildSignatureAlgorithmIdentifier() - { - byte[] rsassaPssOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 10 }); // RSASSA-PSS OID - byte[] pssParams = BuildPssParameters(); - return EncodeAsn1Sequence(new[] { rsassaPssOid, pssParams }); - } - - /// - /// Builds the RSASSA-PSS parameters for SHA256 with MGF1. - /// - private static byte[] BuildPssParameters() - { - var parameters = new List(); - - // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 - // We explicitly specify SHA256 since default is SHA1 - byte[] sha256Oid = EncodeAsn1ObjectIdentifier(new int[] { 2, 16, 840, 1, 101, 3, 4, 2, 1 }); // SHA256 OID - byte[] sha256Null = EncodeAsn1Null(); - byte[] hashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); - byte[] hashAlgorithmParam = EncodeAsn1ContextSpecific(0, hashAlgorithm); - parameters.Add(hashAlgorithmParam); - - // maskGenAlgorithm [1] AlgorithmIdentifier DEFAULT mgf1SHA1 - // We explicitly specify MGF1 with SHA256 - byte[] mgf1Oid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 8 }); // MGF1 OID - byte[] mgf1HashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); // MGF1 uses SHA256 - byte[] maskGenAlgorithm = EncodeAsn1Sequence(new[] { mgf1Oid, mgf1HashAlgorithm }); - byte[] maskGenAlgorithmParam = EncodeAsn1ContextSpecific(1, maskGenAlgorithm); - parameters.Add(maskGenAlgorithmParam); - - // saltLength [2] INTEGER DEFAULT 20 - // We explicitly specify 32 for SHA256 (hash length) - byte[] saltLength = EncodeAsn1Integer(32); - byte[] saltLengthParam = EncodeAsn1ContextSpecific(2, saltLength); - parameters.Add(saltLengthParam); - - // trailerField [3] INTEGER DEFAULT 1 - // Default value is 1 (0xBC), so we omit this parameter - - return EncodeAsn1Sequence(parameters.ToArray()); - } - - /// - /// Signs the CertificationRequestInfo with SHA256withRSA-PSS. - /// - private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) - { - return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); - } - - /// - /// Combines all components into the final CSR structure. - /// - private static byte[] BuildFinalCsr(byte[] certificationRequestInfo, byte[] signatureAlgorithm, byte[] signature) - { - byte[] signatureBitString = EncodeAsn1BitString(signature); - return EncodeAsn1Sequence(new[] { certificationRequestInfo, signatureAlgorithm, signatureBitString }); - } - - /// - /// Converts DER-encoded bytes to PEM format. - /// - private static string ConvertToPem(byte[] derBytes) - { - string base64 = Convert.ToBase64String(derBytes); - var sb = new StringBuilder(); - sb.AppendLine("-----BEGIN CERTIFICATE REQUEST-----"); - - // Split into 64-character lines - for (int i = 0; i < base64.Length; i += 64) - { - int length = Math.Min(64, base64.Length - i); - sb.AppendLine(base64.Substring(i, length)); - } - - sb.AppendLine("-----END CERTIFICATE REQUEST-----"); - return sb.ToString(); - } - - #region ASN.1 Encoding Helpers - - /// - /// Encodes an ASN.1 SEQUENCE. - /// - private static byte[] EncodeAsn1Sequence(byte[][] components) - { - return EncodeAsn1Tag(0x30, ConcatenateByteArrays(components)); - } - - /// - /// Encodes an ASN.1 SEQUENCE without the outer tag (for raw concatenation). - /// - private static byte[] EncodeAsn1SequenceRaw(byte[][] components) - { - return ConcatenateByteArrays(components); - } - - /// - /// Encodes an ASN.1 SET. - /// - private static byte[] EncodeAsn1Set(byte[][] components) - { - return EncodeAsn1Tag(0x31, ConcatenateByteArrays(components)); - } - - /// - /// Encodes an ASN.1 INTEGER. - /// - private static byte[] EncodeAsn1Integer(byte[] value) - { - // Ensure positive integer (prepend 0x00 if high bit is set) - if (value != null && value.Length > 0 && (value[0] & 0x80) != 0) - { - byte[] paddedValue = new byte[value.Length + 1]; - paddedValue[0] = 0x00; - Array.Copy(value, 0, paddedValue, 1, value.Length); - value = paddedValue; - } - return EncodeAsn1Tag(0x02, value ?? new byte[0]); - } - - /// - /// Encodes an ASN.1 INTEGER from an integer value. - /// - private static byte[] EncodeAsn1Integer(int value) - { - if (value == 0) - return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); - - var bytes = new List(); - int temp = value; - while (temp > 0) - { - bytes.Insert(0, (byte)(temp & 0xFF)); - temp >>= 8; - } - - return EncodeAsn1Integer(bytes.ToArray()); - } - - /// - /// Encodes an ASN.1 BIT STRING. - /// - private static byte[] EncodeAsn1BitString(byte[] value) - { - byte[] bitStringValue = new byte[value.Length + 1]; - bitStringValue[0] = 0x00; // No unused bits - Array.Copy(value, 0, bitStringValue, 1, value.Length); - return EncodeAsn1Tag(0x03, bitStringValue); - } - - /// - /// Encodes an ASN.1 UTF8String. - /// - private static byte[] EncodeAsn1Utf8String(string value) - { - byte[] utf8Bytes = Encoding.UTF8.GetBytes(value); - return EncodeAsn1Tag(0x0C, utf8Bytes); - } - - /// - /// Encodes an ASN.1 PrintableString. - /// - private static byte[] EncodeAsn1PrintableString(string value) - { - byte[] asciiBytes = Encoding.ASCII.GetBytes(value); - return EncodeAsn1Tag(0x13, asciiBytes); - } - - /// - /// Encodes an ASN.1 NULL. - /// - private static byte[] EncodeAsn1Null() - { - return new byte[] { 0x05, 0x00 }; - } - - /// - /// Encodes an ASN.1 OBJECT IDENTIFIER. - /// - private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) - { - if (oid == null || oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var bytes = new List(); - - // First two components are encoded as (first * 40 + second) - bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - bytes.AddRange(EncodeOidComponent(oid[i])); - } - - return EncodeAsn1Tag(0x06, bytes.ToArray()); - } - - /// - /// Encodes an ASN.1 context-specific tag. - /// - private static byte[] EncodeAsn1ContextSpecific(int tagNumber, byte[] content) - { - byte tag = (byte)(0xA0 | tagNumber); // Context-specific, constructed - return EncodeAsn1Tag(tag, content); - } - - /// - /// Encodes an ASN.1 tag with length and content. - /// - private static byte[] EncodeAsn1Tag(byte tag, byte[] content) - { - byte[] lengthBytes = EncodeAsn1Length(content.Length); - byte[] result = new byte[1 + lengthBytes.Length + content.Length]; - result[0] = tag; - Array.Copy(lengthBytes, 0, result, 1, lengthBytes.Length); - Array.Copy(content, 0, result, 1 + lengthBytes.Length, content.Length); - return result; - } - - /// - /// Encodes ASN.1 length field. - /// - private static byte[] EncodeAsn1Length(int length) - { - if (length < 0x80) - { - return new byte[] { (byte)length }; - } - - var lengthBytes = new List(); - int temp = length; - while (temp > 0) - { - lengthBytes.Insert(0, (byte)(temp & 0xFF)); - temp >>= 8; - } - - byte[] result = new byte[lengthBytes.Count + 1]; - result[0] = (byte)(0x80 | lengthBytes.Count); - lengthBytes.CopyTo(result, 1); - return result; - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private static byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); - } - - /// - /// Concatenates multiple byte arrays. - /// - private static byte[] ConcatenateByteArrays(byte[][] arrays) - { - int totalLength = 0; - foreach (byte[] array in arrays) - { - totalLength += array.Length; - } - - byte[] result = new byte[totalLength]; - int offset = 0; - foreach (byte[] array in arrays) - { - Array.Copy(array, 0, result, offset, array.Length); - offset += array.Length; - } - - return result; - } - - #endregion - } -} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 35f334d4bf..4e8e7fd8ce 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.PlatformsCommon.Shared; using System.IO; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.ManagedIdentity.V2; namespace Microsoft.Identity.Client.ManagedIdentity { diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs new file mode 100644 index 0000000000..81a2de8d30 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequest.cs @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.ObjectModel; +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class CertificateRequest + { + private X500DistinguishedName _subjectName; + private RSA _rsa; + private HashAlgorithmName _hashAlgorithmName; + private RSASignaturePadding _rsaPadding; + + internal CertificateRequest( + X500DistinguishedName subjectName, + RSA key, + HashAlgorithmName hashAlgorithm, + RSASignaturePadding padding) + { + _subjectName = subjectName; + _rsa = key; + _hashAlgorithmName = hashAlgorithm; + _rsaPadding = padding; + } + + internal Collection OtherRequestAttributes { get; } = new Collection(); + + private static string MakePem(byte[] ber, string header) + { + const int LineLength = 64; + + string base64 = Convert.ToBase64String(ber); + int offset = 0; + + StringBuilder builder = new StringBuilder("-----BEGIN "); + builder.Append(header); + builder.AppendLine("-----"); + + while (offset < base64.Length) + { + int lineEnd = Math.Min(offset + LineLength, base64.Length); + builder.AppendLine(base64.Substring(offset, lineEnd - offset)); + offset = lineEnd; + } + + builder.Append("-----END "); + builder.Append(header); + builder.AppendLine("-----"); + + return builder.ToString(); + } + + internal string CreateSigningRequestPem() + { + byte[] csr = CreateSigningRequest(); + return MakePem(csr, "CERTIFICATE REQUEST"); + } + + internal byte[] CreateSigningRequest() + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Signature Processing has only been written for SHA256"); + } + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + + // RSAPublicKey ::= SEQUENCE { + // modulus INTEGER, -- n + // publicExponent INTEGER -- e + // } + + using (writer.PushSequence()) + { + RSAParameters rsaParameters = _rsa.ExportParameters(false); + writer.WriteIntegerUnsigned(rsaParameters.Modulus); + writer.WriteIntegerUnsigned(rsaParameters.Exponent); + } + + byte[] publicKey = writer.Encode(); + writer.Reset(); + + // CertificationRequestInfo ::= SEQUENCE { + // version INTEGER { v1(0) } (v1,...), + // subject Name, + // subjectPKInfo SubjectPublicKeyInfo{{ PKInfoAlgorithms }}, + // attributes [0] Attributes{{ CRIAttributes }} + // } + // + // SubjectPublicKeyInfo { ALGORITHM: IOSet} ::= SEQUENCE { + // algorithm AlgorithmIdentifier { { IOSet} }, + // subjectPublicKey BIT STRING + // } + // + // Attributes { ATTRIBUTE:IOSet } ::= SET OF Attribute{{ IOSet }} + // + // Attribute { ATTRIBUTE:IOSet } ::= SEQUENCE { + // type ATTRIBUTE.&id({IOSet}), + // values SET SIZE(1..MAX) OF ATTRIBUTE.&Type({IOSet}{@type}) + // } + + using (writer.PushSequence()) + { + writer.WriteInteger(0); + writer.WriteEncodedValue(_subjectName.RawData); + + // subjectPKInfo + using (writer.PushSequence()) + { + // algorithm + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.1"); + // RSA uses an explicit NULL value for parameters + writer.WriteNull(); + } + + writer.WriteBitString(publicKey); + } + + if (OtherRequestAttributes.Count > 0) + { + // attributes + using (writer.PushSetOf(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + foreach (AsnEncodedData attribute in OtherRequestAttributes) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(attribute.Oid.Value); + + using (writer.PushSetOf()) + { + writer.WriteEncodedValue(attribute.RawData); + } + } + } + } + } + } + + byte[] certReqInfo = writer.Encode(); + writer.Reset(); + + // CertificationRequest ::= SEQUENCE { + // certificationRequestInfo CertificationRequestInfo, + // signatureAlgorithm AlgorithmIdentifier{{ SignatureAlgorithms }}, + // signature BIT STRING + // } + + using (writer.PushSequence()) + { + writer.WriteEncodedValue(certReqInfo); + + // signatureAlgorithm + using (writer.PushSequence()) + { + if (_rsaPadding == RSASignaturePadding.Pss) + { + if (_hashAlgorithmName != HashAlgorithmName.SHA256) + { + throw new NotSupportedException("Only SHA256 is supported with PSS padding."); + } + + writer.WriteObjectIdentifier("1.2.840.113549.1.1.10"); + + // RSASSA-PSS-params ::= SEQUENCE { + // hashAlgorithm [0] HashAlgorithm DEFAULT sha1, + // maskGenAlgorithm [1] MaskGenAlgorithm DEFAULT mgf1SHA1, + // saltLength [2] INTEGER DEFAULT 20, + // trailerField [3] TrailerField DEFAULT trailerFieldBC + // } + + using (writer.PushSequence()) + { + string digestOid = "2.16.840.1.101.3.4.2.1"; + + // hashAlgorithm + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 0))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 1))) + { + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.8"); + + using (writer.PushSequence()) + { + writer.WriteObjectIdentifier(digestOid); + } + } + } + + // saltLength (SHA256.Length, 32 bytes) + using (writer.PushSequence(new Asn1Tag(TagClass.ContextSpecific, 2))) + { + writer.WriteInteger(32); + } + + // trailerField 1, which is trailerFieldBC, which is the DEFAULT, + // so don't write it down. + } + } + else if (_rsaPadding == RSASignaturePadding.Pkcs1) + { + writer.WriteObjectIdentifier("1.2.840.113549.1.1.11"); + // RSA PKCS1 uses an explicit NULL value for parameters + writer.WriteNull(); + } + else + { + throw new NotSupportedException("Unsupported RSA padding."); + } + } + + byte[] signature = _rsa.SignData(certReqInfo, _hashAlgorithmName, _rsaPadding); + writer.WriteBitString(signature); + } + + return writer.Encode(); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs similarity index 96% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs index 4391fba4be..a000ea008c 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -7,7 +7,7 @@ using Microsoft.Identity.Json; #endif -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { /// /// Represents the response for a Managed Identity CSR request. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs new file mode 100644 index 0000000000..35d6465d99 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Formats.Asn1; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.Utils; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class Csr + { + internal static string Generate(string clientId, string tenantId, CuidInfo cuid) + { + using (RSA rsa = CreateRsaKeyPair()) + { + CertificateRequest req = new CertificateRequest( + new X500DistinguishedName($"CN={clientId}, DC={tenantId}"), + rsa, + HashAlgorithmName.SHA256, + RSASignaturePadding.Pss); + + AsnWriter writer = new AsnWriter(AsnEncodingRules.DER); + writer.WriteCharacterString(UniversalTagNumber.UTF8String, JsonHelper.SerializeToJson(cuid)); + + req.OtherRequestAttributes.Add( + new AsnEncodedData( + "1.2.840.113549.1.9.7", + writer.Encode())); + + return req.CreateSigningRequestPem(); + } + } + + private static RSA CreateRsaKeyPair() + { + // TODO: use the strongest key on the machine i.e. a TPM key + RSA rsa = null; + +#if NET462 || NET472 + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new RSACng(); +#else + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif + rsa.KeySize = 2048; + return rsa; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs similarity index 97% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs index 5de9a5e490..6281fcec14 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CsrMetadata.cs @@ -7,7 +7,7 @@ using Microsoft.Identity.Json; #endif -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { /// /// Represents VM unique Ids for CSR metadata. diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs similarity index 97% rename from src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs rename to src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index bd0d2c2e4a..dc65aa72ef 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -12,7 +12,7 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Utils; -namespace Microsoft.Identity.Client.ManagedIdentity +namespace Microsoft.Identity.Client.ManagedIdentity.V2 { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { @@ -200,7 +200,7 @@ internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : private async Task ExecuteCertificateRequestAsync( CuidInfo cuid, - string pem) + string csrPem) { var queryParams = $"cuid={JsonHelper.SerializeToJson(cuid)}&cred-api-version={ImdsV2ApiVersion}"; if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) @@ -214,8 +214,7 @@ private async Task ExecuteCertificateRequestAsync( { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } }; - var payload = new PemPayload { pem = pem }; - var body = JsonHelper.SerializeToJson(payload); + var body = $"{{\"pem\":\"{csrPem}\"}}"; IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); @@ -274,9 +273,9 @@ private async Task ExecuteCertificateRequestAsync( protected override ManagedIdentityRequest CreateRequest(string resource) { var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); - var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + var csrPem = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); - var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csr.Pem).GetAwaiter().GetResult(); + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csrPem).GetAwaiter().GetResult(); throw new NotImplementedException(); } diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 578bb27e45..8342355663 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -94,6 +94,7 @@ + @@ -118,6 +119,7 @@ + diff --git a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs index 6d6a6cb7f2..5933d95e58 100644 --- a/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs +++ b/src/client/Microsoft.Identity.Client/Platforms/net/MsalJsonSerializerContext.cs @@ -13,6 +13,7 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Kerberos; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.Region; using Microsoft.Identity.Client.WsTrust; @@ -43,7 +44,6 @@ namespace Microsoft.Identity.Client.Platforms.net [JsonSerializable(typeof(CsrMetadata))] [JsonSerializable(typeof(CuidInfo))] [JsonSerializable(typeof(CertificateRequestResponse))] - [JsonSerializable(typeof(PemPayload))] [JsonSourceGenerationOptions] internal partial class MsalJsonSerializerContext : JsonSerializerContext { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs index adbc2e298d..667c5a6050 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -2,17 +2,22 @@ // Licensed under the MIT License. using System; -using Microsoft.Identity.Client.ManagedIdentity; +using System.Formats.Asn1; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.Utils; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { /// - /// Test helper to expose CsrValidator methods for testing malformed PEM. + /// Helper class for parsing and validating Certificate Signing Request (CSR) content and structure. /// - internal static class TestCsrValidator + internal static class CsrValidator { + /// + /// Parses a PEM-encoded CSR and returns the DER bytes. + /// public static byte[] ParseCsrFromPem(string pemCsr) { if (string.IsNullOrWhiteSpace(pemCsr)) @@ -21,15 +26,13 @@ public static byte[] ParseCsrFromPem(string pemCsr) const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; const string endMarker = "-----END CERTIFICATE REQUEST-----"; - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); + int beginIndex = pemCsr.IndexOf(beginMarker, StringComparison.Ordinal); + int endIndex = pemCsr.IndexOf(endMarker, StringComparison.Ordinal); - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); + if (beginIndex < 0 || endIndex < 0) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + beginIndex += beginMarker.Length; string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) .Replace("\r", "").Replace("\n", "").Replace(" ", ""); @@ -42,390 +45,100 @@ public static byte[] ParseCsrFromPem(string pemCsr) throw new FormatException("Invalid Base64 content in PEM CSR"); } } - } - /// - /// Helper class for validating Certificate Signing Request (CSR) content and structure. - /// - internal static class CsrValidator - { /// /// Validates the content of a CSR PEM string against expected values. /// public static void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) { - // Parse the CSR from PEM format - var csrData = ParseCsrFromPem(pemCsr); - - // Parse the PKCS#10 structure - var csrInfo = ParsePkcs10Structure(csrData); - - // Validate subject name - ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); - - // Validate public key - ValidatePublicKey(csrInfo.PublicKey); - - // Validate CUID attribute - ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); - - // Validate signature algorithm - ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); - } - - /// - /// Parses a PEM-formatted CSR and returns the DER bytes. - /// - private static byte[] ParseCsrFromPem(string pemCsr) - { - if (string.IsNullOrWhiteSpace(pemCsr)) - throw new ArgumentException("PEM CSR cannot be null or empty"); - - const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; - const string endMarker = "-----END CERTIFICATE REQUEST-----"; - - if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) - throw new ArgumentException("Invalid PEM format - missing CSR headers"); - - int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; - int endIndex = pemCsr.IndexOf(endMarker); - - if (beginIndex >= endIndex) - throw new ArgumentException("Invalid PEM format - malformed headers"); - - string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) - .Replace("\r", "").Replace("\n", "").Replace(" ", ""); - - try + byte[] csrBytes = ParseCsrFromPem(pemCsr); + + // Parse the CSR using AsnReader + var reader = new AsnReader(csrBytes, AsnEncodingRules.DER); + var csrSequence = reader.ReadSequence(); + + // certificationRequestInfo + var certReqInfoBytes = csrSequence.PeekEncodedValue().ToArray(); + var certReqInfoReader = new AsnReader(csrSequence.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var certReqInfoSeq = certReqInfoReader.ReadSequence(); + + // version + int version = (int)certReqInfoSeq.ReadInteger(); + Assert.AreEqual(0, version, "CSR version should be 0"); + + // subject + var subjectBytes = certReqInfoSeq.PeekEncodedValue().ToArray(); + var subject = new X500DistinguishedName(certReqInfoSeq.ReadEncodedValue().ToArray()); + string subjectString = subject.Name; + + Assert.IsTrue(subjectString.Contains($"CN={expectedClientId}"), "Client ID (CN) not found in subject"); + Assert.IsTrue(subjectString.Contains($"DC={expectedTenantId}"), "Tenant ID (DC) not found in subject"); + + // subjectPKInfo + var pkInfoReader = new AsnReader(certReqInfoSeq.ReadEncodedValue().ToArray(), AsnEncodingRules.DER); + var pkInfoSeq = pkInfoReader.ReadSequence(); + + // algorithm + var algIdSeq = pkInfoSeq.ReadSequence(); + string algOid = algIdSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.1", algOid, "Public key algorithm is not RSA"); + if (algIdSeq.HasData) { - return Convert.FromBase64String(base64Content); + algIdSeq.ReadNull(); } - catch (FormatException) - { - throw new FormatException("Invalid Base64 content in PEM CSR"); - } - } - /// - /// Represents parsed PKCS#10 CSR information. - /// - private class CsrInfo - { - public byte[] Subject { get; set; } - public byte[] PublicKey { get; set; } - public byte[] Attributes { get; set; } - public byte[] SignatureAlgorithm { get; set; } - } - - /// - /// Parses the PKCS#10 ASN.1 structure and extracts key components. - /// - private static CsrInfo ParsePkcs10Structure(byte[] derBytes) - { - int offset = 0; - - // Parse outer SEQUENCE (CertificationRequest) - var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); - - // Reset offset to parse the CertificationRequestInfo within the outer sequence - int infoOffset = 0; - var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); - - // Parse version (should be 0) - int versionOffset = 0; - var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); - if (version.Length != 1 || version[0] != 0x00) - throw new ArgumentException("Invalid CSR version"); - - // Parse subject - var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse SubjectPublicKeyInfo - var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); - - // Parse attributes (context-specific [0]) - var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); - - return new CsrInfo - { - Subject = subject, - PublicKey = publicKey, - Attributes = attributes, - SignatureAlgorithm = new byte[0] // Simplified for this test - }; - } - - /// - /// Parses an ASN.1 tag and returns its content. - /// - private static byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data"); - - // Check tag (if expectedTag is -1, accept any tag) - if (expectedTag != 255 && data[offset] != expectedTag) - throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); - - offset++; - - // Parse length - int length = ParseAsn1Length(data, ref offset); - - // Extract content - if (offset + length > data.Length) - throw new ArgumentException("Invalid ASN.1 length"); - - byte[] content = new byte[length]; - Array.Copy(data, offset, content, 0, length); - offset += length; - - return content; - } + // subjectPublicKey BIT STRING + var publicKeyBitString = pkInfoSeq.ReadBitString(out _); - /// - /// Parses ASN.1 length encoding. - /// - private static int ParseAsn1Length(byte[] data, ref int offset) - { - if (offset >= data.Length) - throw new ArgumentException("Unexpected end of data in length"); - - byte firstByte = data[offset++]; - - // Short form - if ((firstByte & 0x80) == 0) - return firstByte; - - // Long form - int lengthBytes = firstByte & 0x7F; - if (lengthBytes == 0) - throw new ArgumentException("Indefinite length not supported"); - - if (offset + lengthBytes > data.Length) - throw new ArgumentException("Invalid length encoding"); - - int length = 0; - for (int i = 0; i < lengthBytes; i++) - { - length = (length << 8) | data[offset++]; - } - - return length; - } + // Parse the RSAPublicKey structure from the BIT STRING (SEQUENCE of modulus, exponent) + var rsaKeyReader = new AsnReader(publicKeyBitString, AsnEncodingRules.DER); + var rsaKeySeq = rsaKeyReader.ReadSequence(); + byte[] modulus = rsaKeySeq.ReadIntegerBytes().ToArray(); + byte[] exponent = rsaKeySeq.ReadIntegerBytes().ToArray(); - /// - /// Validates the subject name contains the expected client ID and tenant ID. - /// - private static void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) - { - // Subject is already a SEQUENCE of RDNs - int offset = 0; - bool foundClientId = false; - bool foundTenantId = false; - - // Parse each RDN (Relative Distinguished Name) directly from subjectBytes - while (offset < subjectBytes.Length) - { - var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET - - int rdnOffset = 0; - var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE - - // Parse OID and value - int attrOffset = 0; - var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID - var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type - - string stringValue = System.Text.Encoding.UTF8.GetString(value); - - // Check for CN (commonName) OID: 2.5.4.3 - if (IsOid(oid, new int[] { 2, 5, 4, 3 })) - { - Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); - foundClientId = true; - } - // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 - else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) - { - Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); - foundTenantId = true; - } - } - - Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); - Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); - } + // Validate modulus length (2048 bits = 256 bytes, may have leading zero) + Assert.IsTrue(modulus.Length == 256 || modulus.Length == 257, $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - /// - /// Validates the public key is a valid RSA key. - /// - private static void ValidatePublicKey(byte[] publicKeyBytes) - { - // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content - int offset = 0; - - // Parse algorithm identifier - var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); - - // Parse public key bit string - var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); - - // Validate algorithm is RSA (1.2.840.113549.1.1.1) - int algOffset = 0; - var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); - Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), - "Public key algorithm is not RSA"); - - // Skip the unused bits byte in bit string - if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) - throw new ArgumentException("Invalid public key bit string"); - - // Parse RSA public key (skip unused bits byte) - byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; - Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); - - int rsaOffset = 0; - var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); - - rsaOffset = 0; - var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); - - // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) - Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, - $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); - // Validate exponent (commonly 65537 = 0x010001) Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); - } - /// - /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. - /// Note: VmId is required, VmssId is optional and will be omitted if null/empty. - /// - private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) - { - // Attributes is a SET of attributes - // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) - - int offset = 0; - bool foundCuid = false; - - // Parse each attribute in the SET - while (offset < attributesBytes.Length) + // attributes [0] (optional) + if (certReqInfoSeq.HasData) { - var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); - - int attrOffset = 0; - var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); - var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values - - // Check for challengePassword OID: 1.2.840.113549.1.9.7 - if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + var attrTag = new Asn1Tag(TagClass.ContextSpecific, 0); + if (certReqInfoSeq.PeekTag().HasSameClassAndValue(attrTag)) { - // Parse the value from the SET (should be one value) - int valueOffset = 0; - var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type - - string cuidValue = System.Text.Encoding.ASCII.GetString(value); - - // Build expected CUID value as JSON - string expectedCuidValue = BuildExpectedCuidJson(expectedCuid); - - Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute JSON value does not match expected"); - foundCuid = true; - break; + var attrSetReader = certReqInfoSeq.ReadSetOf(attrTag); + bool foundCuid = false; + while (attrSetReader.HasData) + { + var attrSeq = attrSetReader.ReadSequence(); + string oid = attrSeq.ReadObjectIdentifier(); + if (oid == "1.2.840.113549.1.9.7") // challengePassword + { + var valueSet = attrSeq.ReadSetOf(); + while (valueSet.HasData) + { + string cuidJson = valueSet.ReadCharacterString(UniversalTagNumber.UTF8String); + string expectedCuidJson = JsonHelper.SerializeToJson(expectedCuid); + Assert.AreEqual(expectedCuidJson, cuidJson, "CUID attribute JSON value does not match expected"); + foundCuid = true; + } + } + } + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); } } - - Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); - } - /// - /// Builds the expected CUID JSON string for validation using JsonHelper. - /// - private static string BuildExpectedCuidJson(CuidInfo expectedCuid) - { - return JsonHelper.SerializeToJson(expectedCuid); - } - - /// - /// Validates the signature algorithm is SHA256withRSA. - /// - private static void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) - { - // For this test, we'll just verify that signature algorithm exists - // Full validation would require parsing the outer CSR structure - // which is more complex for this unit test scenario - Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); - } + // signatureAlgorithm + var sigAlgSeq = csrSequence.ReadSequence(); + string sigAlgOid = sigAlgSeq.ReadObjectIdentifier(); + Assert.AreEqual("1.2.840.113549.1.1.10", sigAlgOid, "Signature algorithm is not RSASSA-PSS (SHA256withRSA/PSS)"); - /// - /// Checks if the given OID bytes match the expected OID components. - /// - private static bool IsOid(byte[] oidBytes, int[] expectedOid) - { - if (expectedOid.Length < 2) - return false; - - var expectedBytes = EncodeOid(expectedOid); - - if (oidBytes.Length != expectedBytes.Length) - return false; - - for (int i = 0; i < oidBytes.Length; i++) - { - if (oidBytes[i] != expectedBytes[i]) - return false; - } - - return true; - } + // signature + csrSequence.ReadBitString(out _); - /// - /// Encodes an OID from integer components to bytes (simplified version). - /// - private static byte[] EncodeOid(int[] oid) - { - if (oid.Length < 2) - throw new ArgumentException("OID must have at least 2 components"); - - var result = new System.Collections.Generic.List(); - - // First two components are encoded as (first * 40 + second) - result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); - - // Remaining components - for (int i = 2; i < oid.Length; i++) - { - result.AddRange(EncodeOidComponent(oid[i])); - } - - return result.ToArray(); - } - - /// - /// Encodes a single OID component using variable-length encoding. - /// - private static byte[] EncodeOidComponent(int value) - { - if (value == 0) - return new byte[] { 0x00 }; - - var bytes = new System.Collections.Generic.List(); - int temp = value; - - bytes.Insert(0, (byte)(temp & 0x7F)); - temp >>= 7; - - while (temp > 0) - { - bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); - temp >>= 7; - } - - return bytes.ToArray(); + Assert.IsFalse(csrSequence.HasData, "Extra data found after CSR structure"); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 312d43ec74..afd52214f9 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -7,6 +7,7 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -146,71 +147,10 @@ public void TestCsrGeneration() }; // Generate CSR - var csr = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); // Validate the CSR contents using the helper - CsrValidator.ValidateCsrContent(csr.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); - } - - [DataTestMethod] - [DataRow(null, TestConstants.TenantId)] - [DataRow("", TestConstants.TenantId)] - [DataRow(TestConstants.ClientId, null)] - [DataRow(TestConstants.ClientId, "")] - public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) - { - var cuid = new CuidInfo - { - VmId = TestConstants.VmId, - //VmssId = TestConstants.VmssId - }; - - Assert.ThrowsException(() => - Csr.Generate(clientId, tenantId, cuid)); - } - - [TestMethod] - public void TestCsrGeneration_NullCuid() - { - // Test with null CUID - Assert.ThrowsException(() => - Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); - } - - [DataTestMethod] - [DataRow(null, TestConstants.VmssId)] - [DataRow("", TestConstants.VmssId)] - public void TestCsrGeneration_InvalidVmId(string vmId, string vmssId) - { - var cuid = new CuidInfo - { - VmId = vmId, - //VmssId = vmssId - }; - - // Should throw ArgumentException since VmId is required - Assert.ThrowsException(() => - Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); - } - - [DataTestMethod] - [DataRow(TestConstants.VmId, null)] - [DataRow(TestConstants.VmId, "")] - public void TestCsrGeneration_OptionalVmssId(string vmId, string vmssId) - { - var cuid = new CuidInfo - { - VmId = vmId, - //VmssId = vmssId - }; - - // Should succeed since VmssId is optional (VmId is provided and valid) - var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); - Assert.IsNotNull(csrRequest); - Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); - - // Validate the CSR contents - this should handle null/empty vmssId gracefully - CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] @@ -218,7 +158,7 @@ public void TestCsrGeneration_MalformedPem_FormatException() { string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; Assert.ThrowsException(() => - TestCsrValidator.ParseCsrFromPem(malformedPem)); + CsrValidator.ParseCsrFromPem(malformedPem)); } [DataTestMethod] @@ -228,7 +168,7 @@ public void TestCsrGeneration_MalformedPem_FormatException() public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) { Assert.ThrowsException(() => - TestCsrValidator.ParseCsrFromPem(malformedPem)); + CsrValidator.ParseCsrFromPem(malformedPem)); } } } From 755cf6f5590ca27290d1a0ca2e735587c4d02461 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 22 Aug 2025 18:02:17 -0400 Subject: [PATCH 27/41] Implemented some feedback --- .../V2/ImdsV2ManagedIdentitySource.cs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index f686e54c5c..6f3d8fe320 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -278,17 +278,11 @@ protected override ManagedIdentityRequest CreateRequest(string resource) var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.CuId, csrPem).GetAwaiter().GetResult(); - ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{clientCredentialRequestResponse.RegionalTokenUrl}/{clientCredentialRequestResponse.TenantId}{AcquireEntraTokenPath}")); - + ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.RegionalTokenUrl}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); - - request.BodyParameters.Add("grant_type", clientCredentialRequestResponse.ClientCredential); + request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); + request.BodyParameters.Add("grant_type", certificateRequestResponse.ClientCredential); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); - if (clientCredentialRequestResponse.ClientId != null) - { - request.BodyParameters.Add("client_id", clientCredentialRequestResponse.ClientId); - } - request.RequestType = RequestType.Imds; return request; From 14d05f15e30e3a90022b210049ca6ec5d49fd322 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 22 Aug 2025 18:58:17 -0400 Subject: [PATCH 28/41] additional improvements --- .../ManagedIdentity/ImdsManagedIdentitySource.cs | 3 ++- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 5 +++-- .../Core/Mocks/MockHelpers.cs | 10 ++++++---- .../ManagedIdentityTests/ImdsV2Tests.cs | 4 +++- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index af6be6cf81..d1c98db415 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -18,7 +18,8 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsManagedIdentitySource : AbstractManagedIdentity { // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http - private const string DefaultImdsBaseEndpoint= "http://169.254.169.254"; + // used in unit tests too + public const string DefaultImdsBaseEndpoint= "http://169.254.169.254"; private const string ImdsTokenPath = "/metadata/identity/oauth2/token"; public const string ImdsApiVersion = "2018-02-01"; diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 6f3d8fe320..d845c9a499 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -16,10 +16,11 @@ namespace Microsoft.Identity.Client.ManagedIdentity.V2 { internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { + // used in unit tests public const string ImdsV2ApiVersion = "2.0"; private const string CsrMetadataPath = "/metadata/identity/getplatformmetadata"; - private const string CertificateRequestPath = "/metadata/identity/issuecredential"; - private const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; + public const string CertificateRequestPath = "/metadata/identity/issuecredential"; + public const string AcquireEntraTokenPath = "/oauth2/v2.0/token"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 235758eeee..87e6f7b916 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -8,8 +8,10 @@ using System.Net; using System.Net.Http; using System.Net.Http.Headers; +using System.Web; using Microsoft.Identity.Client; using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.Utils; using Microsoft.Identity.Test.Unit; @@ -626,13 +628,13 @@ public static MockHttpMessageHandler MockCsrResponseFailure() return MockCsrResponse(HttpStatusCode.BadRequest); } - public static MockHttpMessageHandler MockClientCredentialResponse() + public static MockHttpMessageHandler MockCertificateRequestResponse() { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - expectedQueryParams.Add("cid", "%7B%22vmid%22:%22fake_vmid%22,%22vmssid%22:%22fake_vmssid%22%7D"); + expectedQueryParams.Add("cuid", "%7B%22vmId%22:%22fake_vmId%22,%22vmssId%22:%22fake_vmssId%22%7D"); //expectedQueryParams.Add("uaid", "fake_client_id"); - expectedQueryParams.Add("api-version", "2018-02-01"); + expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); expectedRequestHeaders.Add("Metadata", "true"); string content = @@ -647,7 +649,7 @@ public static MockHttpMessageHandler MockClientCredentialResponse() var handler = new MockHttpMessageHandler() { - ExpectedUrl = "http://169.254.169.254/metadata/identity/issuecredential", + ExpectedUrl = $"{ImdsManagedIdentitySource.DefaultImdsBaseEndpoint}{ImdsV2ManagedIdentitySource.CertificateRequestPath}", ExpectedMethod = HttpMethod.Post, ExpectedQueryParams = expectedQueryParams, ExpectedRequestHeaders = expectedRequestHeaders, diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index b06cde64c8..8d482c319d 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -41,7 +41,9 @@ public async Task ImdsV2HappyPathAsync() httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - httpManager.AddMockHandler(MockHelpers.MockClientCredentialResponse()); + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); + // TODO: add a mock handler for acquiring the entra token over an mTLS channel + //httpManager.AddMockHandler(); // TODO: finish this. everything has been tested to this point. /*MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler( From 69e714ab0874d53a3d0f169f66c5a05f127105eb Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 12:28:32 -0400 Subject: [PATCH 29/41] fixed bad rebase --- .../ManagedIdentity/ImdsManagedIdentitySource.cs | 2 +- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index 5f0f283983..ecc7efab1f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -18,7 +18,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsManagedIdentitySource : AbstractManagedIdentity { // IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http - // used in unit tests too + // used in unit tests as well public const string DefaultImdsBaseEndpoint= "http://169.254.169.254"; private const string ImdsTokenPath = "/metadata/identity/oauth2/token"; public const string ImdsApiVersion = "2018-02-01"; diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index edaa719c80..fb42296713 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -123,9 +123,7 @@ private static bool ValidateCsrMetadataResponse( * "1556" // index 1: captured group (\d+) * ] */ - // Imds bug: headers are missing - // TODO: uncomment this when the bug is fixed - /*string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; + string serverHeader = response.HeadersAsDictionary.TryGetValue("server", out var value) ? value : null; if (serverHeader == null) { if (probeMode) @@ -160,7 +158,7 @@ private static bool ValidateCsrMetadataResponse( null, (int)response.StatusCode); } - }*/ + } return true; } From 4219044c05977f54a98aa63c018ac4a04fc77079 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 12:34:16 -0400 Subject: [PATCH 30/41] Fixed bad rebase --- .../ManagedIdentity/V2/CertificateRequestResponse.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs index 51ac149472..5e84000054 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/CertificateRequestResponse.cs @@ -48,8 +48,6 @@ public static void Validate(CertificateRequestResponse certificateRequestRespons ManagedIdentitySource.ImdsV2, (int)HttpStatusCode.OK); } - - return true; } } } From 3bab6e445b8c0e15ce0e4027256998426029bf13 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 12:37:31 -0400 Subject: [PATCH 31/41] Fixed bad rebase --- .../ManagedIdentityTests/ImdsV2Tests.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 77137922c0..a4af522c18 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -114,9 +114,7 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() } } - // Imds bug: headers are missing - // TODO: uncomment this when the bug is fixed - /*[TestMethod] + [TestMethod] public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { using (var httpManager = new MockHttpManager()) @@ -148,7 +146,7 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidVersion() var miSource = await (managedIdentityApp as ManagedIdentityApplication).GetManagedIdentitySourceAsync().ConfigureAwait(false); Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } - }*/ + } [TestMethod] public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() From 6dacdf5ec1353b33d02095fba310e25d4509d3be Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 12:45:13 -0400 Subject: [PATCH 32/41] Adjusted variable names after rebase --- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index fb42296713..60d8714833 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -256,10 +256,10 @@ protected override async Task CreateRequestAsync(string var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); - ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.RegionalTokenUrl}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); + ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId); - request.BodyParameters.Add("grant_type", certificateRequestResponse.ClientCredential); + request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); request.RequestType = RequestType.Imds; From 0df6e1e08bcee823a8962249e01eb4ae907512ef Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 12:54:28 -0400 Subject: [PATCH 33/41] wrote the skeleton for the mTLS cert creation --- .../ManagedIdentity/AbstractManagedIdentity.cs | 4 ++-- .../ManagedIdentity/ManagedIdentityRequest.cs | 10 +++++++++- .../ManagedIdentity/V2/Csr.cs | 4 ++-- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 8 +++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index 276fe67c78..67434999cb 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -82,7 +82,7 @@ public virtual async Task AuthenticateAsync( method: HttpMethod.Get, logger: _requestContext.Logger, doNotThrow: true, - mtlsCertificate: null, + mtlsCertificate: request.MtlsCertificate, validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: retryPolicy).ConfigureAwait(false); @@ -97,7 +97,7 @@ public virtual async Task AuthenticateAsync( method: HttpMethod.Post, logger: _requestContext.Logger, doNotThrow: true, - mtlsCertificate: null, + mtlsCertificate: request.MtlsCertificate, validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: retryPolicy) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs index c5b9af2b73..6a7161d2c0 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Security.Cryptography.X509Certificates; using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.OAuth2; @@ -26,7 +27,13 @@ internal class ManagedIdentityRequest public RequestType RequestType { get; set; } - public ManagedIdentityRequest(HttpMethod method, Uri endpoint, RequestType requestType = RequestType.ManagedIdentityDefault) + public X509Certificate2 MtlsCertificate { get; set; } + + public ManagedIdentityRequest( + HttpMethod method, + Uri endpoint, + RequestType requestType = RequestType.ManagedIdentityDefault, + X509Certificate2 mtlsCertificate = null) { Method = method; _baseEndpoint = endpoint; @@ -34,6 +41,7 @@ public ManagedIdentityRequest(HttpMethod method, Uri endpoint, RequestType reque BodyParameters = new Dictionary(); QueryParameters = new Dictionary(); RequestType = requestType; + MtlsCertificate = mtlsCertificate; } public Uri ComputeUri() diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs index 3f3b1175a3..3c90605140 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/Csr.cs @@ -10,7 +10,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity.V2 { internal class Csr { - internal static string Generate(string clientId, string tenantId, CuidInfo cuid) + internal static (string csrPem, RSA privateKey) Generate(string clientId, string tenantId, CuidInfo cuid) { using (RSA rsa = CreateRsaKeyPair()) { @@ -28,7 +28,7 @@ internal static string Generate(string clientId, string tenantId, CuidInfo cuid) "1.3.6.1.4.1.311.90.2.10", writer.Encode())); - return req.CreateSigningRequestPem(); + return (req.CreateSigningRequestPem(), rsa); } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 60d8714833..1c61c67136 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -252,9 +252,14 @@ private async Task ExecuteCertificateRequestAsync(st protected override async Task CreateRequestAsync(string resource) { var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); - var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + var (csr, privateKey) = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); + + // transform certificateRequestResponse.ClientCredential to x509 + var mtlsCertificate = CreateCertificateWithPrivateKey( // TODO: implement this method + certificateRequestResponse.Certificate, + privateKey); ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); @@ -262,6 +267,7 @@ protected override async Task CreateRequestAsync(string request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); request.RequestType = RequestType.Imds; + request.MtlsCertificate = mtlsCertificate; return request; } From a66a933bc158748cc927d83d1f9da601073c6899 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Wed, 27 Aug 2025 13:06:30 -0400 Subject: [PATCH 34/41] adjusted unit test based on new code --- .../ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs | 6 +++--- .../ManagedIdentityTests/ImdsV2Tests.cs | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 1c61c67136..bb42509392 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -257,9 +257,9 @@ protected override async Task CreateRequestAsync(string var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); // transform certificateRequestResponse.ClientCredential to x509 - var mtlsCertificate = CreateCertificateWithPrivateKey( // TODO: implement this method + /*var mtlsCertificate = CreateCertificateWithPrivateKey( // TODO: implement this method certificateRequestResponse.Certificate, - privateKey); + privateKey);*/ ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); @@ -267,7 +267,7 @@ protected override async Task CreateRequestAsync(string request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); request.RequestType = RequestType.Imds; - request.MtlsCertificate = mtlsCertificate; + //request.MtlsCertificate = mtlsCertificate; return request; } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index a4af522c18..7f0d9092b2 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -194,8 +194,8 @@ public void TestCsrGeneration_OnlyVmId() VmId = TestConstants.VmId }; - var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); - CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); + var (csr, _) = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] @@ -207,8 +207,8 @@ public void TestCsrGeneration_VmIdAndVmssId() VmssId = TestConstants.VmssId }; - var csrPem = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); - CsrValidator.ValidateCsrContent(csrPem, TestConstants.ClientId, TestConstants.TenantId, cuid); + var (csr, _) = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid); } [TestMethod] From 67cc4a675e3261f1c4288f24a171fedf04358ee4 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 28 Aug 2025 15:41:07 -0400 Subject: [PATCH 35/41] Implemented mTLS --- .../V2/ImdsV2ManagedIdentitySource.cs | 135 ++++++++++- .../Core/Mocks/MockHelpers.cs | 9 +- .../ManagedIdentityTests/ImdsV2Tests.cs | 216 ++++++++++++++++-- 3 files changed, 335 insertions(+), 25 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index bb42509392..abd4e9c847 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -5,6 +5,8 @@ using System.Collections.Generic; using System.Net; using System.Net.Http; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; @@ -256,10 +258,10 @@ protected override async Task CreateRequestAsync(string var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); - // transform certificateRequestResponse.ClientCredential to x509 - /*var mtlsCertificate = CreateCertificateWithPrivateKey( // TODO: implement this method + // transform certificateRequestResponse.Certificate to x509 with private key + var mtlsCertificate = AttachPrivateKeyToCert( certificateRequestResponse.Certificate, - privateKey);*/ + privateKey); ManagedIdentityRequest request = new(HttpMethod.Post, new Uri($"{certificateRequestResponse.MtlsAuthenticationEndpoint}/{certificateRequestResponse.TenantId}{AcquireEntraTokenPath}")); request.Headers.Add("x-ms-client-request-id", _requestContext.CorrelationId.ToString()); @@ -267,11 +269,136 @@ protected override async Task CreateRequestAsync(string request.BodyParameters.Add("grant_type", certificateRequestResponse.Certificate); request.BodyParameters.Add("scope", "https://management.azure.com/.default"); request.RequestType = RequestType.Imds; - //request.MtlsCertificate = mtlsCertificate; + request.MtlsCertificate = mtlsCertificate; return request; } + /// + /// Attaches a private key to a certificate for use in mTLS authentication. + /// + /// The certificate in PEM format + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when certificatePem or privateKey is null + /// Thrown when certificatePem is not a valid PEM certificate + /// Thrown when the certificate cannot be parsed + internal X509Certificate2 AttachPrivateKeyToCert(string certificatePem, RSA privateKey) + { + if (string.IsNullOrEmpty(certificatePem)) + throw new ArgumentNullException(nameof(certificatePem)); + if (privateKey == null) + throw new ArgumentNullException(nameof(privateKey)); + + X509Certificate2 certificate; + +#if NET8_0_OR_GREATER + // .NET 8.0+ has direct PEM parsing support + certificate = X509Certificate2.CreateFromPem(certificatePem); + // Attach the private key and return a new certificate instance + return certificate.CopyWithPrivateKey(privateKey); +#else + // .NET Framework 4.7.2 and .NET Standard 2.0 - manual PEM parsing and private key attachment + certificate = ParseCertificateFromPem(certificatePem); + return AttachPrivateKeyToOlderFrameworks(certificate, privateKey); +#endif + } + +#if !NET8_0_OR_GREATER + /// + /// Parses a certificate from PEM format for older .NET versions. + /// + /// The certificate in PEM format + /// An X509Certificate2 instance + /// Thrown when the PEM format is invalid + /// Thrown when the Base64 content cannot be decoded + internal static X509Certificate2 ParseCertificateFromPem(string certificatePem) + { + const string CertBeginMarker = "-----BEGIN CERTIFICATE-----"; + const string CertEndMarker = "-----END CERTIFICATE-----"; + + int startIndex = certificatePem.IndexOf(CertBeginMarker, StringComparison.Ordinal); + if (startIndex == -1) + { + throw new ArgumentException("Invalid PEM format: missing BEGIN CERTIFICATE marker", nameof(certificatePem)); + } + + startIndex += CertBeginMarker.Length; + int endIndex = certificatePem.IndexOf(CertEndMarker, startIndex, StringComparison.Ordinal); + if (endIndex == -1) + { + throw new ArgumentException("Invalid PEM format: missing END CERTIFICATE marker", nameof(certificatePem)); + } + + string base64Content = certificatePem.Substring(startIndex, endIndex - startIndex) + .Replace("\r", "") + .Replace("\n", "") + .Replace(" ", ""); + + if (string.IsNullOrEmpty(base64Content)) + { + throw new ArgumentException("Invalid PEM format: no certificate content found", nameof(certificatePem)); + } + + try + { + byte[] certBytes = Convert.FromBase64String(base64Content); + return new X509Certificate2(certBytes); + } + catch (FormatException ex) + { + throw new FormatException("Invalid PEM format: certificate content is not valid Base64", ex); + } + } + + /// + /// Attaches a private key to a certificate for older .NET Framework versions. + /// This method uses the older RSACng approach for .NET Framework 4.7.2 and .NET Standard 2.0. + /// + /// The certificate without private key + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when private key attachment fails + internal X509Certificate2 AttachPrivateKeyToOlderFrameworks(X509Certificate2 certificate, RSA privateKey) + { + try + { + // For older frameworks, we need to use the legacy approach with RSACryptoServiceProvider + // First, export the RSA parameters from the provided private key + var parameters = privateKey.ExportParameters(includePrivateParameters: true); + + // Create a new RSACryptoServiceProvider with the correct key size + int keySize = parameters.Modulus.Length * 8; + var rsaProvider = new RSACryptoServiceProvider(keySize); + + try + { + // Import the parameters into the new provider + rsaProvider.ImportParameters(parameters); + + // Create a new certificate instance from the raw data + var certWithPrivateKey = new X509Certificate2(certificate.RawData); + + // Assign the private key using the legacy property + certWithPrivateKey.PrivateKey = rsaProvider; + + return certWithPrivateKey; + } + catch + { + // Clean up the RSA provider if something goes wrong + rsaProvider?.Dispose(); + throw; + } + } + catch (Exception ex) + { + throw new NotSupportedException( + "Failed to attach private key to certificate on this .NET Framework version.", ex); + } + } +#endif + private static string ImdsV2QueryParamsHelper(RequestContext requestContext) { var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 07a59cbab2..37bbf7d9bd 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -632,7 +632,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse() { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - expectedQueryParams.Add("cuid", "%7B%22vmId%22:%22fake_vmId%22,%22vmssId%22:%22fake_vmssId%22%7D"); + expectedQueryParams.Add("cuid", "%7B%22vmId%22:%22fake_vmId%22"); //expectedQueryParams.Add("uaid", "fake_client_id"); expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); expectedRequestHeaders.Add("Metadata", "true"); @@ -641,10 +641,9 @@ public static MockHttpMessageHandler MockCertificateRequestResponse() "{" + "\"client_id\": \"fake_client_id\"," + "\"tenant_id\": \"fake_tenant_id\"," + - "\"client_credential\": \"fake_client_credential\"," + - "\"regional_token_url\": \"fake_regional_token_url\"," + - "\"expires_in\": 3600," + - "\"refresh_in\": 1800" + + "\"certificate\": \"fake_certificate\"," + + "\"identity_type\": \"fake_identity_type\"," + + "\"mtls_authentication_endpoint\": \"fake_mtls_authentication_endpoint\"," + "}"; var handler = new MockHttpMessageHandler() diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 7f0d9092b2..4e7389092a 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -2,7 +2,9 @@ // Licensed under the MIT License. using System; +using System.Drawing; using System.Net; +using System.Security.Cryptography; using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; @@ -20,16 +22,35 @@ public class ImdsV2Tests : TestBase { private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); - [TestMethod] + // Test constants for certificate testing - using a real self-signed certificate + private const string ValidPemCertificate = @"-----BEGIN CERTIFICATE----- +MIIDUTCCAjmgAwIBAgIUPS20Ik/lV4SSwHHHJGPSlG7j5SgwDQYJKoZIhvcNAQEL +BQAwNzEWMBQGA1UEAwwNVW5pdFRlc3REdW1teTEQMA4GA1UECgwHVGVzdE9yZzEL +MAkGA1UEBhMCVVMwIBcNMjUwODI4MTcxMTA3WhgPMjI5OTA2MTIxNzExMDdaMDcx +FjAUBgNVBAMMDVVuaXRUZXN0RHVtbXkxEDAOBgNVBAoMB1Rlc3RPcmcxCzAJBgNV +BAYTAlVTMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwFX8Gqz2g4Hf +dRhrNiP8oNiZ4IwO4bra9wdCR03PEKgYv1GL1Uj0OfhKSt+8WLng43da1p3jBh2P +79IRdRLLJFX4LEJaPWW2/qUCRBpA4eMmSEBRSt1hYGtMNaKdBtxDpOxCBRpofV7Z +PPTrg682ZHAlZ5K5PK9mWfRzV1C/NmSg8FtnD24VWrdkh1waqt40OzrE16JzmPpu +2YDfXilM3G5Zq4uxHXQVCrmchBSVf7frsz+LSnMU1kn45AqDjsqufxH5+CDOtFvM +R7794+HKOdzl20U+npfbtVGKIfcWh+kRcZyrLj6DER09ehVz8VWLYgntY+8riDcl +UAfGh0RNswIDAQABo1MwUTAdBgNVHQ4EFgQUbR0id2PPztRSAoggeu0eqNFwtTAw +HwYDVR0jBBgwFoAUbR0id2PPztRSAoggeu0eqNFwtTAwDwYDVR0TAQH/BAUwAwEB +/zANBgkqhkiG9w0BAQsFAAOCAQEAje7eY+MtaBo0TmeF6fM14H5MtD7cYqdFVyIa +KeVWOxwNDtwbwRyfcDlkgcXK8gLeIZA1MNBY/juTx6qy8RsHPdNSTImDVw3t7guq +2CqrA+tqU5E+wah+XzltIvbjqTvRV/20FccfcXAkyM/aWl3WHNkFYNSziT+Ug3QQ +qPABEWvXOjo4BEgrCmQJSIprLgjtfjFSK/LS/VDpRqsSa+3mmx/Dw4FY3rfEqKzv +4RPSFxE8uF/05ByoIaAJZ2JcffDZW8PI5+qwsNatCsypyRADJE1jXLzqZnFFBLW7 +dj80Qbs0xLeK0U/Aq1kFf0stgdwbDoHaJj9Q4TlSHZuI0TnjSg== +-----END CERTIFICATE-----"; + public async Task ImdsV2HappyPathAsync() { using (var httpManager = new MockHttpManager()) { - /*ManagedIdentityId managedIdentityId = userAssignedId == null - ? ManagedIdentityId.SystemAssigned - : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); - var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) - .WithHttpManager(httpManager);*/ + //ManagedIdentityId managedIdentityId = userAssignedId == null + // ? ManagedIdentityId.SystemAssigned + // : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory); @@ -43,16 +64,7 @@ public async Task ImdsV2HappyPathAsync() httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); // TODO: add a mock handler for acquiring the entra token over an mTLS channel - //httpManager.AddMockHandler(); - - // TODO: finish this. everything has been tested to this point. - /*MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler( - "MachineLearningEndpoint", - ManagedIdentityTests.Resource, - MockHelpers.GetMsiSuccessfulResponse(), - ManagedIdentitySource.ImdsV2//, - // userAssignedId: userAssignedId, - // userAssignedIdentityId);*/ + //httpManager.AddMockHandler() // this will fail, see TODO above var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) @@ -228,5 +240,177 @@ public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem Assert.ThrowsException(() => CsrValidator.ParseCsrFromPem(malformedPem)); } + + #region AttachPrivateKeyToCert Tests + + [TestMethod] + public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() + { + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + // For this test, we just want to verify that the method doesn't crash + // The actual certificate/private key matching isn't critical for the unit test + var exception = Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(ValidPemCertificate, rsa)); + + // The test should fail with a NotSupportedException because the RSA key doesn't match + // the certificate, but this validates that the method is working correctly + Assert.AreEqual( + "Failed to attach private key to certificate on this .NET Framework version.", + exception.Message); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullException() + { + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(null, rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullException() + { + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert("", rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() + { + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(ValidPemCertificate, null)); + } + + [TestMethod] + public void AttachPrivateKeyToCert_InvalidPemFormat_ThrowsArgumentException() + { + const string InvalidPemNoCertMarker = @"This is not a valid PEM certificate"; + + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(InvalidPemNoCertMarker, rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_MissingBeginMarker_ThrowsArgumentException() + { + const string InvalidPemMissingBeginMarker = @"MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +-----END CERTIFICATE-----"; + + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(InvalidPemMissingBeginMarker, rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_MissingEndMarker_ThrowsArgumentException() + { + const string InvalidPemMissingEndMarker = @"-----BEGIN CERTIFICATE----- +MIICXTCCAUWgAwIBAgIJAKPiQh26MIuPMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV"; + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(InvalidPemMissingEndMarker, rsa)); + } + } + + [TestMethod] + public void AttachPrivateKeyToCert_BadBase64Content_ThrowsFormatException() + { + const string InvalidPemBadBase64 = @"-----BEGIN CERTIFICATE----- +Invalid@#$%Base64Content! +-----END CERTIFICATE-----"; + + using var httpManager = new MockHttpManager(); + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + var managedIdentityApp = miBuilder.BuildConcrete(); + + var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); + + using (RSA rsa = RSA.Create()) + { + Assert.ThrowsException(() => + imdsV2Source.AttachPrivateKeyToCert(InvalidPemBadBase64, rsa)); + } + } + + #endregion } } From ca6f1d606125e34e6af35ed5fe47c4dac552c7d6 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 28 Aug 2025 15:42:07 -0400 Subject: [PATCH 36/41] Removed un-used imports --- global.json | 2 +- tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs | 1 - .../ManagedIdentityTests/ImdsV2Tests.cs | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/global.json b/global.json index 66e4a5c8a7..e5135e9ff3 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "8.0.404", + "version": "9.0.0", "rollForward": "latestFeature" } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 37bbf7d9bd..ceb8bc21bd 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -8,7 +8,6 @@ using System.Net; using System.Net.Http; using System.Net.Http.Headers; -using System.Web; using Microsoft.Identity.Client; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 4e7389092a..dd4e30cc2e 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using System; -using System.Drawing; using System.Net; using System.Security.Cryptography; using System.Threading.Tasks; From 3fc3ecee84fa19c466f8ca4fe3a7d66b61a37ad0 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 28 Aug 2025 15:42:51 -0400 Subject: [PATCH 37/41] Undo changes to global.json --- global.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/global.json b/global.json index e5135e9ff3..66e4a5c8a7 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "9.0.0", + "version": "8.0.404", "rollForward": "latestFeature" } } From 131b8cdf4745c1dc5995a1c443c391bda7107f1d Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 28 Aug 2025 18:08:40 -0400 Subject: [PATCH 38/41] Implemented unit test + helpers --- .../AppConfig/ApplicationConfiguration.cs | 2 + .../BaseAbstractApplicationBuilder.cs | 18 +++++++ .../ManagedIdentity/V2/DefaultCsrFactory.cs | 15 ++++++ .../ManagedIdentity/V2/ICsrFactory.cs | 12 +++++ .../V2/ImdsV2ManagedIdentitySource.cs | 2 +- .../Microsoft.Identity.Client.csproj | 7 +++ .../Core/Mocks/MockHelpers.cs | 5 +- .../Core/Mocks/MockHttpManagerExtensions.cs | 9 ++++ .../TestConstants.cs | 40 +++++++++++++++ .../Helpers/TestCsrFactory.cs | 37 ++++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 49 +++++++------------ 11 files changed, 160 insertions(+), 36 deletions(-) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs create mode 100644 tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs diff --git a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs index ab19425b9e..aa23fd7fd3 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs @@ -17,6 +17,7 @@ using Microsoft.Identity.Client.Internal.Broker; using Microsoft.Identity.Client.Internal.ClientCredential; using Microsoft.Identity.Client.Kerberos; +using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Client.UI; using Microsoft.IdentityModel.Abstractions; @@ -126,6 +127,7 @@ public string ClientVersion public Func> AppTokenProvider; internal IRetryPolicyFactory RetryPolicyFactory { get; set; } + internal ICsrFactory CsrFactory { get; set; } #region ClientCredentials diff --git a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs index b60ae2dbe0..a1c0d6c5f1 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs @@ -15,6 +15,7 @@ using Microsoft.IdentityModel.Abstractions; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Http.Retry; +using Microsoft.Identity.Client.ManagedIdentity.V2; #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; @@ -39,6 +40,12 @@ internal BaseAbstractApplicationBuilder(ApplicationConfiguration configuration) { Config.RetryPolicyFactory = new RetryPolicyFactory(); } + + // Ensure the default csr factory is set if the test factory was not provided + if (Config.CsrFactory == null) + { + Config.CsrFactory = new DefaultCsrFactory(); + } } internal ApplicationConfiguration Config { get; } @@ -246,6 +253,17 @@ internal T WithRetryPolicyFactory(IRetryPolicyFactory factory) return (T)this; } + /// + /// Internal only: Allows tests to inject a custom csr factory. + /// + /// The csr factory to use. + /// The builder for chaining. + internal T WithCsrFactory(ICsrFactory factory) + { + Config.CsrFactory = factory; + return (T)this; + } + internal virtual ApplicationConfiguration BuildConfiguration() { ResolveAuthority(); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs new file mode 100644 index 0000000000..edbd183edb --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/DefaultCsrFactory.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal class DefaultCsrFactory : ICsrFactory + { + public (string csrPem, RSA privateKey) Generate(string clientId, string tenantId, CuidInfo cuid) + { + return Csr.Generate(clientId, tenantId, cuid); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs new file mode 100644 index 0000000000..84bae9409d --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ICsrFactory.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; + +namespace Microsoft.Identity.Client.ManagedIdentity.V2 +{ + internal interface ICsrFactory + { + (string csrPem, RSA privateKey) Generate(string clientId, string tenantId, CuidInfo cuid); + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index abd4e9c847..6ae0dc730f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -254,7 +254,7 @@ private async Task ExecuteCertificateRequestAsync(st protected override async Task CreateRequestAsync(string resource) { var csrMetadata = await GetCsrMetadataAsync(_requestContext, false).ConfigureAwait(false); - var (csr, privateKey) = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); + var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId); var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index 8342355663..e5b053dc76 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -80,6 +80,8 @@ + + @@ -165,4 +167,9 @@ + + + + + diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index ceb8bc21bd..55f510917a 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -631,7 +631,6 @@ public static MockHttpMessageHandler MockCertificateRequestResponse() { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - expectedQueryParams.Add("cuid", "%7B%22vmId%22:%22fake_vmId%22"); //expectedQueryParams.Add("uaid", "fake_client_id"); expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); expectedRequestHeaders.Add("Metadata", "true"); @@ -640,9 +639,9 @@ public static MockHttpMessageHandler MockCertificateRequestResponse() "{" + "\"client_id\": \"fake_client_id\"," + "\"tenant_id\": \"fake_tenant_id\"," + - "\"certificate\": \"fake_certificate\"," + + "\"certificate\": \"" + TestConstants.ValidPemCertificate + "\"," + "\"identity_type\": \"fake_identity_type\"," + - "\"mtls_authentication_endpoint\": \"fake_mtls_authentication_endpoint\"," + + "\"mtls_authentication_endpoint\": \"http://fake_mtls_authentication_endpoint\"," + "}"; var handler = new MockHttpMessageHandler() diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 7f8667d93f..bbebe4e18a 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -460,6 +460,15 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource( expectedQueryParams.Add("resource", resource); expectedRequestHeaders.Add("Metadata", "true"); break; + case ManagedIdentitySource.ImdsV2: + httpMessageHandler.ExpectedMethod = HttpMethod.Post; + expectedPostData = new Dictionary + { + { "client_id", "fake_client_id" }, + { "grant_type", TestConstants.ValidPemCertificate }, + { "scope", resource } + }; + break; case ManagedIdentitySource.CloudShell: httpMessageHandler.ExpectedMethod = HttpMethod.Post; expectedRequestHeaders.Add("Metadata", "true"); diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 5a2ea2986a..819af3f33c 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -583,6 +583,46 @@ public static MsalTokenResponse CreateAadTestTokenResponseWithFoci() internal const string UserAccessToken = "flMpQIKiCoiPK6qISSjmF9dGhKe47KFGPwe82BDBxBCVfYI4UiKYbBuShsjf8oGTsjN5ODeaO6k0cmZJYuNNbLyOr8JGqoxQRW9bI8j5ETpbTNf6tYpAWde9PIYj2wEBnbughVgtJsh2QxIrahie5leMpsGb1yoFzADD5gyoJq8etNUSgZwe5qkfaE9UBCUKrznKjKbsG5hBJXut5GD0QdQy3wo2PnocewrptlMzd5SsHCzUUBGA4q7ks7IfrLiQH11JyBnjBhypOX3XvuqBz4JKkpftVYfvwPWE3f5Onku6FkZJFFESyGQP9YnJVx5dQCpHH9l6ShTqOLSQduf7wxoyeAgxwPrM9Y8Kvj31IrXqiwP52x4hBsctLCqOXOZ3wMXnozMXyHpNvKMJaNgDgvBgMYhiyORkb3qKYw0gAP4659I8dK1esxJoD8I3EreDftGfNMFCgn7kFfauUQphkqx8ukqzw068R7g5TOUci1pgPcVXCAMxj0P3fTiKe1doVuF6znKYh3m7pjyzyaqb5K9VFIh4A8TXOO0MqjaVkoSWJXARTy4T0kAZBVPbO6U2BWku23yLIt43MhQTc9uf7inuirwaIgh5u7noDxYG4QZLB1CJl04Zq2gbh9GW7dqweAaC9efYTEDwhxDTPHeGTQs44e8cnWerIyZA7mq8sFuzihIiCfgZ6nNBPcx2lXKyarUtQGmjjRyOEAhs66atv3SgMhNBhontPoUhR1QEnTKeYzfaavlnf5qMZA41hijGazHyxy5FgLD5aLEpZTHN5MPQLeaEXzDMX5Wtdvq7nokiItRfLkKZtXkuSiFVltmRPcKqzGbjNRH96OQzuxLE1Mv25FYFR3PAwv6np69yScVOpNFL8CqJdT310dGnRPUKSrEqTPuMsHqVRr36j2ZUaGs6YBtcrxIxKHuPrv23FQg5fC0FgxZvKqve0hf68AocJ1HqKRy01CGQobmYpTwBByftOZYGC4KOfGd13l78kZaKLuk2gxfFuTQyr11A0L4n5tXfjlikJtr3wlTGt0KCGGXmNK1xsSoRC0VcXDOgQUu3FHblhiaYjbSvPRF09xn9tRPnUkznbsT1kPMiJ8v89ZOCtVWpvkoiy9VUVcSUpZNQwRh3wHidZAkp1xyjyVc2pIHPg6XhzJnlt77zHNiBkPxWbYt7hXBQf3QeYoMF4s0Qi1y5N72DdoSNJ3iaTwx3esAz6TeyxSh36PIz35mR5jGyGMssyaNg6lIewLPbjnizgC6xssi6mKOheDqWqBv89nIvSBOXEkKcUYsBlhBBK6BgxOIha1NAeP93RRKfyjrF7LtIoSOk3DJUx75rUJ9oyuuTt4FdSnp7ZdrIciO8vlNslPrfa7UjBdOtVHiaz9Ef91dctdADVFcwXXmcu2ypyKB1YvMbkPP7mc12TF1a8X6t0mU4s4J4IpA3SHmT5JvbQBEzOIs6ex38X3UtXSItxpaS2gKozAhAmvjt6NKMe3Jysm4bafH1kb8eB1vdwTQu3jIOGozqHC3rvqEVAt26NNKOuNYAoYYamQOSb2w8PUCuDDWs1ffLvvfyvRndZztV5C4HGGR1Tg82N291Sb7rSUYmA1rdGyJ4kPtSaiPOwMyPUs9FuZNef5Ib83D3gTcgS1gMxto5UkfSxtCDKLXtGKArOdACrRzHiiMSn3owQfyVtSXZPdeofoCzuPWcZzFLBUJR0iKWBpUkxd0N17vw45uMQpQUNGgGoyvyboKkAFlOGsEIAmrnooC3CJGVA4jHPYJnVG4xTJ37U6QL5sX95qWtjbvuD5KoT2GyWec0o62CNr09tCQsiALLC1QrfCiCGsullefbsgBB5tsOY1Kyiy4uf84qBMu20GbsJ01R8xxpJ5bh6HFRaStEK3WIy7TMJym42YMbxB3AGsGFGhNYljtuqgeUjXn1UuWskkB6QqdepFHCof6CHg0LlV0o4Iz9QKu5cfoi8jk5HKbvIGyDqCgZaC2LdugNgQ0X"; internal const string RefreshToken = "mhJDJ8wtjA3KxpRtuPAreZnMcJ2yKC2JUbpOGbRTdOCImLyQ2B4EIhv8AiA2cCEylZZfZsOsZrNsMBZZAAU9TQYYEO72QcdfnIWpAOeKkud5W2L8nMq6i9dx1EVIl09zFXhOJ79BdFbU0Eb5aUHlcqPCQjec62UKBLkZJmtMnoAa8cjvgIuxTdVM8FNdghe5nlCNTEVooKleTTEHNl2BrdyitLaWTKSP0lRqnFxriG0xWcJoSMsdS7Vt6HZd1TkwHIXycNMlCcCdUh5tOgqx1M8y8uoXK4OJ1LQmtkZvcQWcycvOCPACYakKM1pUQqwTxI6Y4HrL38sqQaSNxpF9OcFxOQWpuGodRekCbxXVbWclttIpvSOLaBhZ2ZBpcCBEeEMSmhqqYgajNwwwe9w88u0UsYKe6PBbaI48ENr02u2qBeLsIQ2HUyKlN3iVmX7u7MhgDWA3NNavMtlLmWd63NfuDgXpLI0O4cLhjAx8uoBIK8LntXPHPTxJ28o0yrszvD4gf7RdhuTq5VE15zne6iAJgIGfy7latGFzxuDMcML9OoXURHnNEHBgS9ZQCfNzYZ2O9flF1UjGpcBLEi7hHVHnrQb4y7c98dz9p62cvEMhorGx9kCwSIkOae5LheXPQkFIbsGyomNEwz3HZvR131VGAwdfmUUodvPr6LAAtmjl4sZ72PRqAo8EdQ0IFsWoypXVv51IooR87tO3uiG2DkxhIAwumOQdaJNxw1a0WS9mpQOmwFlvfbZkaIoUKgagHc8fVa1aHZntLGwH0S1iYixJiIrMnPYAeRdSp9mlHllrMX8xUIznobcZ5i8MpUYCKlUXMZ82S3XUJ5dJxARNRPxXlLJ5LPYBhUNkBLQen9Qmq3VZEV1RDJyhbGp6GAo14KsMtVAVYNmYPIgo85pCZgOwVEOBUycszu4AD3p4PT2ella4LVoqmTTMSA5GEWoeWb5JvEo222Z0oKr7UK8dGwpWRSbg8TNeODihJaTUDfErvbgaZnjIRpqfgtM5i1HfQbD7Yyft5PqyygUra7GYy7pjRrEvq95XQD8sAZ32ku9AqCo5qOB584iX881WErOoheQZokt1txqwuIMUyhVuMKNEXy70CeNTsb30ghQMZpZcXIkrLYyQCZ0gNmARhMKagCSdrpUtxudLk44yfmuwSQzBN3ifWfLZiFpU53qdPLZoTw5"; internal const string IdToken = "6GwdM7f6hHXfivavPozhaRqrbxvEysfXSMQyEKBwVgivPZTtmowsmYygchhIuxjeFFeq1ZPHjhxKFnulrvoY6TDerZY5xyOlg45bToI9Bu95qFvUrrt5r17UJcXdw4YkvEt10CcDDcLcEYw704RpVefvbpjbF24pOgIuafcAkDnbDA0Qea4ePuSC45Lw7zpJhbo9Gh8IfMX597fayBvMs3fh7frrm9KpWMCeKY3h99YSaCYjZFKp1ppvXXPE9bc4sh4pRDOfnv0Yr9J8u4elZevEE4qGddfgd3hYb18XPGRjPEMlWsh7tnwxwUm6OSZlMTHYuvwBENNMx7SUQmMeg4rCfgnbcNDkWpXCiSDVt1lLLv8F2GjYnM6De3v1Ks5lhBWx3grLggcN9LnXz92eJ1l5lTB2v0y9MgmFZ4gY43oIOW5n8G5HOx3bGOyjTw0TKKbyVa3mDj0A3QqW8eLTUJz42BNiGOf5m9prMSlpAW59CHCMJLatsj3IvGeCITsGAr3sUZEytORWUdxCfuIPwecQgU6bO7pNqNvZc1tJHHNwJlfS23ZkiFuEXqEThHYfxBCFxAzMDlzO0TOdWhvrb8hlNeAOcNhoAKxu7HXsePajKs4fU1rcdSxzNKwtASEla3p6jfJnnDtKf38RJZPaRRYMviqqWEMhjmqIvBm7sMaf8RyNNuYl7otZwmwNVCR1hzzmaTAy4kQce67FJqFba7uizrgwp9zsvK8muCHKKPvNthy7fHsxKmrBIm0bLcoePKK3wAID4kFvNQcxXp6rAOr8bLFF3bLEoYdzmF2QJz1frVZZHHPy90Cmlhw48EQN8NE2OllpdaykKt5k4rPcZQyitayNNhism30qh7eCBhcA7mm5Ja0S8X4VPlkwvgwg0mQuul6gakmja8xpnTrwiOdtao320GDmJaJA6zf3UTpNZTq9tdfBtUrjAD8RS0tNUBT3Ko8N2Lfh9ry8y9ESmRVIhch3rKY7UeefFAnkiwH2WwC57ZEsHtMP0SwKYtYKHZW9HkERCCyqOT1Mw0IavsLGFvchzMAvTnz4RwRBk6IrWgANvqT3F3Vexc2K0poKb71XZ4aMXxjqAzydGQAKpKJEJcqEvX9RD8nL76TF2LZIepiaZ3dbQImkqSjbF7aaY2JFoN9ZWlcSQKe8zdO8TIG16bF8W9R4ldDyzV39L33KcweG"; + + #region Test Certificate and Private Key (ValidPemCertificate & XmlPrivateKey) + /// + /// A test PEM-encoded X.509 certificate and its matching RSA private key. + /// These are used together in unit tests that require both a certificate and its private key. + /// The and are a matched pair: + /// - is a PEM-encoded certificate. + /// - is the corresponding RSA private key in XML format. + /// The certificate is valid for 100 years, ensuring it will not expire during the lifetime of the tests. + /// + internal const string ValidPemCertificate = @"-----BEGIN CERTIFICATE----- +MIIDATCCAemgAwIBAgIUSfjghyQB4FIS41rWfNcZHTLE/R4wDQYJKoZIhvcNAQEL +BQAwDzENMAsGA1UEAwwEVGVzdDAgFw0yNTA4MjgyMDIxMDBaGA8yMTI1MDgwNDIw +MjEwMFowDzENMAsGA1UEAwwEVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBALlc0S6TdwgQKGRl3Y/9uWNRpWo1WHiZtd1YdgCBt0rjxTqsbQUurU0B +9Kdk7QQ9srxmjimxGHaUFypbb39awqIdQQcuQvIUj5+sQh9zzCyR35bGQp8vwbna +5GlhAIbzsUi/y5kEGUMbuQN05XfoJSQrU35XZ8duQSDH5h9aDr6kuLcpDHo9/9vZ +iosPfqGPxZGtVjMvrJdVQGLJF35xD3LlX8xG2iJfVK/xYQVi3MgbRNQaL2lHtZaG +Ac1CToMUPO60xXrZkQE08hC907YTBcavUVQg4vrOaPpsCs+Fj6EJcasADAJeh1mG +Bn3kHFPCxBa2MKFraFPp53zOagTvYV0CAwEAAaNTMFEwHQYDVR0OBBYEFA9irQR/ +O6/V2JVyDEHFOdUDjAsyMB8GA1UdIwQYMBaAFA9irQR/O6/V2JVyDEHFOdUDjAsy +MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAAOxtgYjtkUDVvWz +q/lkjLTdcLjPvmH0hF34A3uvX4zcjmqF845lfvszTuhc1mx5J6YLEzKfr4TrO3D3 +g2BnDLvhupok0wEmJ9yVwbt1laim7zP09gZqnUqYM9hYKDhwgLZAaG3zGNocxDEA +U7jazMGOGF7TweB7LdNuVI6CqgDOBQ8Cy2ObuZvzCI5Y7f+HucXpiJOu1xNa2ZZp +MpQycYEvi5TD+CL5CBv2fcKQRn/+u5B3ZXCD2C9jT/RZ7rH46mIG7nC7dS4J2o4J +jmlJIUAe2U6tRay5GvEmc/nZK8hd9y4BICzrykp9ENAoy9i+uaE1GGWeNgO+irrc +rAcLwto= +-----END CERTIFICATE-----"; + internal const string XmlPrivateKey = @" + uVzRLpN3CBAoZGXdj/25Y1GlajVYeJm13Vh2AIG3SuPFOqxtBS6tTQH0p2TtBD2yvGaOKbEYdpQXKltvf1rCoh1BBy5C8hSPn6xCH3PMLJHflsZCny/BudrkaWEAhvOxSL/LmQQZQxu5A3Tld+glJCtTfldnx25BIMfmH1oOvqS4tykMej3/29mKiw9+oY/Fka1WMy+sl1VAYskXfnEPcuVfzEbaIl9Ur/FhBWLcyBtE1BovaUe1loYBzUJOgxQ87rTFetmRATTyEL3TthMFxq9RVCDi+s5o+mwKz4WPoQlxqwAMAl6HWYYGfeQcU8LEFrYwoWtoU+nnfM5qBO9hXQ== + AQAB +

3pGBJXfhILNTsbRLHmUy7YVvD75HpvMCey2aaN4gU9Jvi1s2vQFU15a8p75Yt8UYHZDr+Yqwl1Jd4J+UtWsGqGBGNB1Ae4V1dwR8zUDKxXXee7G/dCDnIu4xpkZbPD+brcULcpF/Tdq/WsTbpCNhPgjHuo8hQY3vFv1NMla8mr0=

+ 1TSgE9DfTeqk0qybQM1r83M5ZwWKV0mPQBZl1VMs+VplB6E/6JAYWCKiq9ewgocOaktK94jtEtsaDhYeyojZFBlukt1lKp4kmkUwUSEmi3EFsprNakg+Bm6t85tEm5he5mG1ivHlE3M5lBWJ2A0r1g3jWSjYJlkk2nOwFE8bmyE= + UIcU0xmsusgnYAR7qWO0KXw90tRl2GHUY/z8ATVdPPbGpQU7qObya45+c7LLJrKJJyloN8GWYynKDZuvknRG1GUBAZoT2p1PAuD8xsbKlucuuFJ3kuzUtC66iA6ss//Ps++3VJyQEvsygQT480pZxLgoi7d9sNpJx2eeprf7RYE= + zwIZqyPSrUR2ZFdTJshNWEM4KN8oQzgY7pDQrx/jOviZv57A/n1qJaj7aP4zU4juZiZU06MPDI/P7H1tyBi3LNzEj7SG1apWv7MOBre5RQqoDZJggCFEl9o+65iGNMzs16NnMVFMqmXmMfH3tN6VAXDanWca96D2N2S8QfvNQgE= + Uoxh1dskd3C0N7SQ1nJXW7FyjB+J54R5yAcd8Zk0ukunhtuzsziQH4ZoMhBuzwxRwOaw0Umj77EcdEevuvFHn6LAK/solK2lkRcuKY2QTgkbYyYOxZNa1pJJaAfgzSGsBiwiGtHXl2eFLb2jfYDa4V/SV2B6BPOVheSUQGZlyYM= + Lkq21wnu7S2T2NbzyVUVKm+mfurJqHzCxX+lIKVEkEhn5ipPo76vew7k+bUj2C5MZ+64zEK1GFANpP9mzghtmSzzI4bzIx/tanQLo2047VyU2UO0Oaskl3TKHGMkTY+ok8GKaDF02aSfxPQ5poNsWycS1/eeLFklnLkviF7mVcfCoStSHAb+8dQzxO22Mu+oN2rXHinoNDSmFzUTx8cJapQhgji+GADRKF77Sfa5tHk/hCzVUXGBHgBs1jJM9cin2BBij8PngOaAAlby4gr07/r8SZU2uuXoxEDhpxf6mRTET5Wr2hxAyhu3bpZeCc0LokckNkzJPGUG6JaXXdUcgQ== +
"; + #endregion } internal static class Adfs2019LabConstants diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs new file mode 100644 index 0000000000..6edd3936bb --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestCsrFactory.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Security.Cryptography; +using Microsoft.Identity.Client.ManagedIdentity.V2; + +namespace Microsoft.Identity.Test.Unit.Helpers +{ + internal class TestCsrFactory : ICsrFactory + { + public (string csrPem, RSA privateKey) Generate(string clientId, string tenantId, CuidInfo cuId) + { + return ("mock-csr", CreateMockRsa()); + } + + /// + /// Creates a mock private key for testing purposes by loading key parameters from an XML string. + /// The XML format is used because it allows all necessary RSA parameters to be embedded directly in the code, + /// enabling deterministic and repeatable test runs. This method returns an object rather than a string, + /// as cryptographic operations in tests require a usable key instance, not just its serialized representation. + /// + public static RSA CreateMockRsa() + { + RSA rsa = null; + +#if NET462 || NET472 + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new RSACng(); +#else + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif + rsa.FromXmlString(TestConstants.XmlPrivateKey); + return rsa; + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index dd4e30cc2e..5e1e708751 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -13,6 +13,7 @@ using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; +using OpenTelemetry.Resources; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { @@ -20,52 +21,37 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests public class ImdsV2Tests : TestBase { private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + private readonly TestCsrFactory _testCsrFactory = new TestCsrFactory(); - // Test constants for certificate testing - using a real self-signed certificate - private const string ValidPemCertificate = @"-----BEGIN CERTIFICATE----- -MIIDUTCCAjmgAwIBAgIUPS20Ik/lV4SSwHHHJGPSlG7j5SgwDQYJKoZIhvcNAQEL -BQAwNzEWMBQGA1UEAwwNVW5pdFRlc3REdW1teTEQMA4GA1UECgwHVGVzdE9yZzEL -MAkGA1UEBhMCVVMwIBcNMjUwODI4MTcxMTA3WhgPMjI5OTA2MTIxNzExMDdaMDcx -FjAUBgNVBAMMDVVuaXRUZXN0RHVtbXkxEDAOBgNVBAoMB1Rlc3RPcmcxCzAJBgNV -BAYTAlVTMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwFX8Gqz2g4Hf -dRhrNiP8oNiZ4IwO4bra9wdCR03PEKgYv1GL1Uj0OfhKSt+8WLng43da1p3jBh2P -79IRdRLLJFX4LEJaPWW2/qUCRBpA4eMmSEBRSt1hYGtMNaKdBtxDpOxCBRpofV7Z -PPTrg682ZHAlZ5K5PK9mWfRzV1C/NmSg8FtnD24VWrdkh1waqt40OzrE16JzmPpu -2YDfXilM3G5Zq4uxHXQVCrmchBSVf7frsz+LSnMU1kn45AqDjsqufxH5+CDOtFvM -R7794+HKOdzl20U+npfbtVGKIfcWh+kRcZyrLj6DER09ehVz8VWLYgntY+8riDcl -UAfGh0RNswIDAQABo1MwUTAdBgNVHQ4EFgQUbR0id2PPztRSAoggeu0eqNFwtTAw -HwYDVR0jBBgwFoAUbR0id2PPztRSAoggeu0eqNFwtTAwDwYDVR0TAQH/BAUwAwEB -/zANBgkqhkiG9w0BAQsFAAOCAQEAje7eY+MtaBo0TmeF6fM14H5MtD7cYqdFVyIa -KeVWOxwNDtwbwRyfcDlkgcXK8gLeIZA1MNBY/juTx6qy8RsHPdNSTImDVw3t7guq -2CqrA+tqU5E+wah+XzltIvbjqTvRV/20FccfcXAkyM/aWl3WHNkFYNSziT+Ug3QQ -qPABEWvXOjo4BEgrCmQJSIprLgjtfjFSK/LS/VDpRqsSa+3mmx/Dw4FY3rfEqKzv -4RPSFxE8uF/05ByoIaAJZ2JcffDZW8PI5+qwsNatCsypyRADJE1jXLzqZnFFBLW7 -dj80Qbs0xLeK0U/Aq1kFf0stgdwbDoHaJj9Q4TlSHZuI0TnjSg== ------END CERTIFICATE-----"; - - public async Task ImdsV2HappyPathAsync() + //TODO: Clean up this method. Use constants, etc. + [TestMethod] + public async Task ImdsV2SAMIHappyPathAsync() { using (var httpManager = new MockHttpManager()) { + // TODO: Implement DataTestMethod. SAMI + UAMI //ManagedIdentityId managedIdentityId = userAssignedId == null // ? ManagedIdentityId.SystemAssigned // : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) - .WithRetryPolicyFactory(_testRetryPolicyFactory); + .WithRetryPolicyFactory(_testRetryPolicyFactory) + .WithCsrFactory(_testCsrFactory); // Disabling shared cache options to avoid cross test pollution. miBuilder.Config.AccessorOptions = null; var mi = miBuilder.Build(); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); - httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // initial probe + httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // do it again, since CsrMetadata from initial probe is not cached httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); - // TODO: add a mock handler for acquiring the entra token over an mTLS channel - //httpManager.AddMockHandler() + httpManager.AddManagedIdentityMockHandler( + "http://fake_mtls_authentication_endpoint/fake_tenant_id/oauth2/v2.0/token", + "https://management.azure.com", + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.ImdsV2); - // this will fail, see TODO above var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); @@ -73,7 +59,6 @@ public async Task ImdsV2HappyPathAsync() Assert.IsNotNull(result.AccessToken); Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); - // this will fail, see TODO above result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) .ExecuteAsync().ConfigureAwait(false); @@ -259,7 +244,7 @@ public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() // For this test, we just want to verify that the method doesn't crash // The actual certificate/private key matching isn't critical for the unit test var exception = Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(ValidPemCertificate, rsa)); + imdsV2Source.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, rsa)); // The test should fail with a NotSupportedException because the RSA key doesn't match // the certificate, but this validates that the method is working correctly @@ -320,7 +305,7 @@ public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(ValidPemCertificate, null)); + imdsV2Source.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, null)); } [TestMethod] From d839578dcb2a8474d4d0bccb7ae839acecf16678 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 28 Aug 2025 18:21:48 -0400 Subject: [PATCH 39/41] Undid changes to csproj --- .../Microsoft.Identity.Client.csproj | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj index e5b053dc76..8342355663 100644 --- a/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj +++ b/src/client/Microsoft.Identity.Client/Microsoft.Identity.Client.csproj @@ -80,8 +80,6 @@ - - @@ -167,9 +165,4 @@ - - - - - From 71182357544388ea407fd4b26b71fd0137f4eb16 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 29 Aug 2025 14:23:35 -0400 Subject: [PATCH 40/41] Implemented feedback --- .../V2/ImdsV2ManagedIdentitySource.cs | 135 +----------------- .../Shared/CommonCryptographyManager.cs | 109 ++++++++++++++ .../ManagedIdentityTests/ImdsV2Tests.cs | 40 +++--- 3 files changed, 135 insertions(+), 149 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs index 6ae0dc730f..d329bfbfa8 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs @@ -5,13 +5,12 @@ using System.Collections.Generic; using System.Net; using System.Net.Http; -using System.Security.Cryptography; -using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity.V2 @@ -28,6 +27,10 @@ public static async Task GetCsrMetadataAsync( RequestContext requestContext, bool probeMode) { +#if NET462 + requestContext.Logger.Info(() => "[Managed Identity] IMDSv2 flow is not supported on .NET Framework 4.6.2. Cryptographic operations required for managed identity authentication are unavailable on this platform. Skipping IMDSv2 probe."); + return await Task.FromResult(null).ConfigureAwait(false); +#else var queryParams = ImdsV2QueryParamsHelper(requestContext); var headers = new Dictionary @@ -93,6 +96,7 @@ public static async Task GetCsrMetadataAsync( } return TryCreateCsrMetadata(response, requestContext.Logger, probeMode); +#endif } private static void ThrowProbeFailedException( @@ -259,7 +263,7 @@ protected override async Task CreateRequestAsync(string var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false); // transform certificateRequestResponse.Certificate to x509 with private key - var mtlsCertificate = AttachPrivateKeyToCert( + var mtlsCertificate = CommonCryptographyManager.AttachPrivateKeyToCert( certificateRequestResponse.Certificate, privateKey); @@ -274,131 +278,6 @@ protected override async Task CreateRequestAsync(string return request; } - /// - /// Attaches a private key to a certificate for use in mTLS authentication. - /// - /// The certificate in PEM format - /// The RSA private key to attach - /// An X509Certificate2 with the private key attached - /// Thrown when certificatePem or privateKey is null - /// Thrown when certificatePem is not a valid PEM certificate - /// Thrown when the certificate cannot be parsed - internal X509Certificate2 AttachPrivateKeyToCert(string certificatePem, RSA privateKey) - { - if (string.IsNullOrEmpty(certificatePem)) - throw new ArgumentNullException(nameof(certificatePem)); - if (privateKey == null) - throw new ArgumentNullException(nameof(privateKey)); - - X509Certificate2 certificate; - -#if NET8_0_OR_GREATER - // .NET 8.0+ has direct PEM parsing support - certificate = X509Certificate2.CreateFromPem(certificatePem); - // Attach the private key and return a new certificate instance - return certificate.CopyWithPrivateKey(privateKey); -#else - // .NET Framework 4.7.2 and .NET Standard 2.0 - manual PEM parsing and private key attachment - certificate = ParseCertificateFromPem(certificatePem); - return AttachPrivateKeyToOlderFrameworks(certificate, privateKey); -#endif - } - -#if !NET8_0_OR_GREATER - /// - /// Parses a certificate from PEM format for older .NET versions. - /// - /// The certificate in PEM format - /// An X509Certificate2 instance - /// Thrown when the PEM format is invalid - /// Thrown when the Base64 content cannot be decoded - internal static X509Certificate2 ParseCertificateFromPem(string certificatePem) - { - const string CertBeginMarker = "-----BEGIN CERTIFICATE-----"; - const string CertEndMarker = "-----END CERTIFICATE-----"; - - int startIndex = certificatePem.IndexOf(CertBeginMarker, StringComparison.Ordinal); - if (startIndex == -1) - { - throw new ArgumentException("Invalid PEM format: missing BEGIN CERTIFICATE marker", nameof(certificatePem)); - } - - startIndex += CertBeginMarker.Length; - int endIndex = certificatePem.IndexOf(CertEndMarker, startIndex, StringComparison.Ordinal); - if (endIndex == -1) - { - throw new ArgumentException("Invalid PEM format: missing END CERTIFICATE marker", nameof(certificatePem)); - } - - string base64Content = certificatePem.Substring(startIndex, endIndex - startIndex) - .Replace("\r", "") - .Replace("\n", "") - .Replace(" ", ""); - - if (string.IsNullOrEmpty(base64Content)) - { - throw new ArgumentException("Invalid PEM format: no certificate content found", nameof(certificatePem)); - } - - try - { - byte[] certBytes = Convert.FromBase64String(base64Content); - return new X509Certificate2(certBytes); - } - catch (FormatException ex) - { - throw new FormatException("Invalid PEM format: certificate content is not valid Base64", ex); - } - } - - /// - /// Attaches a private key to a certificate for older .NET Framework versions. - /// This method uses the older RSACng approach for .NET Framework 4.7.2 and .NET Standard 2.0. - /// - /// The certificate without private key - /// The RSA private key to attach - /// An X509Certificate2 with the private key attached - /// Thrown when private key attachment fails - internal X509Certificate2 AttachPrivateKeyToOlderFrameworks(X509Certificate2 certificate, RSA privateKey) - { - try - { - // For older frameworks, we need to use the legacy approach with RSACryptoServiceProvider - // First, export the RSA parameters from the provided private key - var parameters = privateKey.ExportParameters(includePrivateParameters: true); - - // Create a new RSACryptoServiceProvider with the correct key size - int keySize = parameters.Modulus.Length * 8; - var rsaProvider = new RSACryptoServiceProvider(keySize); - - try - { - // Import the parameters into the new provider - rsaProvider.ImportParameters(parameters); - - // Create a new certificate instance from the raw data - var certWithPrivateKey = new X509Certificate2(certificate.RawData); - - // Assign the private key using the legacy property - certWithPrivateKey.PrivateKey = rsaProvider; - - return certWithPrivateKey; - } - catch - { - // Clean up the RSA provider if something goes wrong - rsaProvider?.Dispose(); - throw; - } - } - catch (Exception ex) - { - throw new NotSupportedException( - "Failed to attach private key to certificate on this .NET Framework version.", ex); - } - } -#endif - private static string ImdsV2QueryParamsHelper(RequestContext requestContext) { var queryParams = $"cred-api-version={ImdsV2ApiVersion}"; diff --git a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs index 20fc279fc4..187df64051 100644 --- a/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs +++ b/src/client/Microsoft.Identity.Client/PlatformsCommon/Shared/CommonCryptographyManager.cs @@ -111,5 +111,114 @@ byte[] SignDataAndCacheProvider(string message) return signedData; } } + + /// + /// Attaches a private key to a certificate for use in mTLS authentication. + /// + /// The certificate in PEM format + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when certificatePem or privateKey is null + /// Thrown when certificatePem is not a valid PEM certificate + /// Thrown when the certificate cannot be parsed + internal static X509Certificate2 AttachPrivateKeyToCert(string certificatePem, RSA privateKey) + { + if (string.IsNullOrEmpty(certificatePem)) + throw new ArgumentNullException(nameof(certificatePem)); + if (privateKey == null) + throw new ArgumentNullException(nameof(privateKey)); + + X509Certificate2 certificate; + +#if NET8_0_OR_GREATER + // .NET 8.0+ has direct PEM parsing support + certificate = X509Certificate2.CreateFromPem(certificatePem); + // Attach the private key and return a new certificate instance + return certificate.CopyWithPrivateKey(privateKey); +#else + // .NET Framework 4.7.2 and .NET Standard 2.0 - manual PEM parsing and private key attachment + certificate = ParseCertificateFromPem(certificatePem); + return AttachPrivateKeyToOlderFrameworks(certificate, privateKey); +#endif + } + +#if !NET8_0_OR_GREATER + /// + /// Parses a certificate from PEM format for older .NET versions. + /// + /// The certificate in PEM format + /// An X509Certificate2 instance + /// Thrown when the PEM format is invalid + /// Thrown when the Base64 content cannot be decoded + private static X509Certificate2 ParseCertificateFromPem(string certificatePem) + { + const string CertBeginMarker = "-----BEGIN CERTIFICATE-----"; + const string CertEndMarker = "-----END CERTIFICATE-----"; + + int startIndex = certificatePem.IndexOf(CertBeginMarker, StringComparison.Ordinal); + if (startIndex == -1) + { + throw new ArgumentException("Invalid PEM format: missing BEGIN CERTIFICATE marker", nameof(certificatePem)); + } + + startIndex += CertBeginMarker.Length; + int endIndex = certificatePem.IndexOf(CertEndMarker, startIndex, StringComparison.Ordinal); + if (endIndex == -1) + { + throw new ArgumentException("Invalid PEM format: missing END CERTIFICATE marker", nameof(certificatePem)); + } + + string base64Content = certificatePem.Substring(startIndex, endIndex - startIndex) + .Replace("\r", "") + .Replace("\n", "") + .Replace(" ", ""); + + if (string.IsNullOrEmpty(base64Content)) + { + throw new ArgumentException("Invalid PEM format: no certificate content found", nameof(certificatePem)); + } + + try + { + byte[] certBytes = Convert.FromBase64String(base64Content); + return new X509Certificate2(certBytes); + } + catch (FormatException ex) + { + throw new FormatException("Invalid PEM format: certificate content is not valid Base64", ex); + } + } + + /// + /// Attaches a private key to a certificate for older .NET Framework versions. + /// This method uses the older RSACng approach for .NET Framework 4.7.2 and .NET Standard 2.0. + /// + /// The certificate without private key + /// The RSA private key to attach + /// An X509Certificate2 with the private key attached + /// Thrown when private key attachment fails + private static X509Certificate2 AttachPrivateKeyToOlderFrameworks(X509Certificate2 certificate, RSA privateKey) + { + // For older frameworks, we need to use the legacy approach with RSACryptoServiceProvider + // First, export the RSA parameters from the provided private key + var parameters = privateKey.ExportParameters(includePrivateParameters: true); + + // Create a new RSACryptoServiceProvider with the correct key size + int keySize = parameters.Modulus.Length * 8; + using (var rsaProvider = new RSACryptoServiceProvider(keySize)) + { + // Import the parameters into the new provider + rsaProvider.ImportParameters(parameters); + + // Create a new certificate instance from the raw data + var certWithPrivateKey = new X509Certificate2(certificate.RawData); + + // Assign the private key using the legacy property + certWithPrivateKey.PrivateKey = rsaProvider; + + return certWithPrivateKey; + } + } +#endif } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 5e1e708751..c8aa2093ce 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -10,10 +10,10 @@ using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; +using Microsoft.Identity.Client.PlatformsCommon.Shared; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; -using OpenTelemetry.Resources; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { @@ -243,14 +243,12 @@ public void AttachPrivateKeyToCert_ValidInputs_ReturnsValidCertificate() { // For this test, we just want to verify that the method doesn't crash // The actual certificate/private key matching isn't critical for the unit test - var exception = Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, rsa)); + var exception = Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, rsa)); - // The test should fail with a NotSupportedException because the RSA key doesn't match + // The test should fail with a CryptographicUnexpectedOperationException because the RSA key doesn't match // the certificate, but this validates that the method is working correctly - Assert.AreEqual( - "Failed to attach private key to certificate on this .NET Framework version.", - exception.Message); + Assert.IsNotNull(exception.Message); } } @@ -268,8 +266,8 @@ public void AttachPrivateKeyToCert_NullCertificatePem_ThrowsArgumentNullExceptio using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(null, rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(null, rsa)); } } @@ -287,8 +285,8 @@ public void AttachPrivateKeyToCert_EmptyCertificatePem_ThrowsArgumentNullExcepti using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert("", rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert("", rsa)); } } @@ -304,8 +302,8 @@ public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() var requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); var imdsV2Source = new ImdsV2ManagedIdentitySource(requestContext); - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, null)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(TestConstants.ValidPemCertificate, null)); } [TestMethod] @@ -324,8 +322,8 @@ public void AttachPrivateKeyToCert_InvalidPemFormat_ThrowsArgumentException() using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(InvalidPemNoCertMarker, rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemNoCertMarker, rsa)); } } @@ -346,8 +344,8 @@ public void AttachPrivateKeyToCert_MissingBeginMarker_ThrowsArgumentException() using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(InvalidPemMissingBeginMarker, rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemMissingBeginMarker, rsa)); } } @@ -367,8 +365,8 @@ public void AttachPrivateKeyToCert_MissingEndMarker_ThrowsArgumentException() using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(InvalidPemMissingEndMarker, rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemMissingEndMarker, rsa)); } } @@ -390,8 +388,8 @@ public void AttachPrivateKeyToCert_BadBase64Content_ThrowsFormatException() using (RSA rsa = RSA.Create()) { - Assert.ThrowsException(() => - imdsV2Source.AttachPrivateKeyToCert(InvalidPemBadBase64, rsa)); + Assert.ThrowsException(() => + CommonCryptographyManager.AttachPrivateKeyToCert(InvalidPemBadBase64, rsa)); } } From 8fcac990f12dc9b9bcfa35331245256e2e7fd7a8 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Fri, 29 Aug 2025 16:47:36 -0400 Subject: [PATCH 41/41] Improved unit tests. Added UAMI unit tests. --- .../Core/Mocks/MockHelpers.cs | 35 +++++++++--- .../Core/Mocks/MockHttpManagerExtensions.cs | 2 +- .../TestConstants.cs | 1 + .../ManagedIdentityTests/ImdsV2Tests.cs | 56 ++++++++++++++++--- 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 55f510917a..04665fc0dd 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -8,12 +8,16 @@ using System.Net; using System.Net.Http; using System.Net.Http.Headers; +using Castle.Core.Logging; using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.ManagedIdentity.V2; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Client.Utils; using Microsoft.Identity.Test.Unit; +using Microsoft.VisualStudio.TestTools.UnitTesting.Logging; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; namespace Microsoft.Identity.Test.Common.Core.Mocks { @@ -587,18 +591,25 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce public static MockHttpMessageHandler MockCsrResponse( HttpStatusCode statusCode = HttpStatusCode.OK, - string responseServerHeader = "IMDS/150.870.65.1854") + string responseServerHeader = "IMDS/150.870.65.1854", + UserAssignedIdentityId idType = UserAssignedIdentityId.None, + string userAssignedId = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); + if (idType != UserAssignedIdentityId.None && userAssignedId != null) + { + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); + } expectedQueryParams.Add("cred-api-version", "2.0"); expectedRequestHeaders.Add("Metadata", "true"); string content = "{" + "\"cuId\": { \"vmId\": \"fake_vmId\" }," + - "\"clientId\": \"fake_client_id\"," + - "\"tenantId\": \"fake_tenant_id\"," + + "\"clientId\": \"" + TestConstants.ClientId + "\"," + + "\"tenantId\": \"" + TestConstants.TenantId + "\"," + "\"attestationEndpoint\": \"fake_attestation_endpoint\"" + "}"; @@ -627,21 +638,27 @@ public static MockHttpMessageHandler MockCsrResponseFailure() return MockCsrResponse(HttpStatusCode.BadRequest); } - public static MockHttpMessageHandler MockCertificateRequestResponse() + public static MockHttpMessageHandler MockCertificateRequestResponse( + UserAssignedIdentityId idType = UserAssignedIdentityId.None, + string userAssignedId = null) { IDictionary expectedQueryParams = new Dictionary(); IDictionary expectedRequestHeaders = new Dictionary(); - //expectedQueryParams.Add("uaid", "fake_client_id"); + if (idType != UserAssignedIdentityId.None && userAssignedId != null) + { + var userAssignedIdQueryParam = ImdsManagedIdentitySource.GetUserAssignedIdQueryParam((ManagedIdentityIdType)idType, userAssignedId, null); + expectedQueryParams.Add(userAssignedIdQueryParam.Value.Key, userAssignedIdQueryParam.Value.Value); + } expectedQueryParams.Add("cred-api-version", ImdsV2ManagedIdentitySource.ImdsV2ApiVersion); expectedRequestHeaders.Add("Metadata", "true"); string content = "{" + - "\"client_id\": \"fake_client_id\"," + - "\"tenant_id\": \"fake_tenant_id\"," + + "\"client_id\": \"" + TestConstants.ClientId + "\"," + + "\"tenant_id\": \"" + TestConstants.TenantId + "\"," + "\"certificate\": \"" + TestConstants.ValidPemCertificate + "\"," + - "\"identity_type\": \"fake_identity_type\"," + - "\"mtls_authentication_endpoint\": \"http://fake_mtls_authentication_endpoint\"," + + "\"identity_type\": \"fake_identity_type\"," + // "SystemAssigned" or "UserAssigned", it doesn't matter for these tests + "\"mtls_authentication_endpoint\": \"" + TestConstants.MtlsAuthenticationEndpoint + "\"," + "}"; var handler = new MockHttpMessageHandler() diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index bbebe4e18a..017213d275 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -464,7 +464,7 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource( httpMessageHandler.ExpectedMethod = HttpMethod.Post; expectedPostData = new Dictionary { - { "client_id", "fake_client_id" }, + { "client_id", TestConstants.ClientId }, { "grant_type", TestConstants.ValidPemCertificate }, { "scope", resource } }; diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 819af3f33c..35107976c7 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -156,6 +156,7 @@ public static HashSet s_scope public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; public const string VmId = "test-vm-id"; public const string VmssId = "test-vmss-id"; + public const string MtlsAuthenticationEndpoint = "http://fake_mtls_authentication_endpoint"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index c8aa2093ce..25322ec08b 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -14,6 +14,7 @@ using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests { @@ -23,16 +24,11 @@ public class ImdsV2Tests : TestBase private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); private readonly TestCsrFactory _testCsrFactory = new TestCsrFactory(); - //TODO: Clean up this method. Use constants, etc. [TestMethod] public async Task ImdsV2SAMIHappyPathAsync() { using (var httpManager = new MockHttpManager()) { - // TODO: Implement DataTestMethod. SAMI + UAMI - //ManagedIdentityId managedIdentityId = userAssignedId == null - // ? ManagedIdentityId.SystemAssigned - // : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) .WithHttpManager(httpManager) .WithRetryPolicyFactory(_testRetryPolicyFactory) @@ -47,8 +43,54 @@ public async Task ImdsV2SAMIHappyPathAsync() httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); // do it again, since CsrMetadata from initial probe is not cached httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse()); httpManager.AddManagedIdentityMockHandler( - "http://fake_mtls_authentication_endpoint/fake_tenant_id/oauth2/v2.0/token", - "https://management.azure.com", + $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.ImdsV2); + + var result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + + [DataTestMethod] + [DataRow(UserAssignedIdentityId.ClientId, TestConstants.ClientId)] + [DataRow(UserAssignedIdentityId.ResourceId, TestConstants.MiResourceId)] + [DataRow(UserAssignedIdentityId.ObjectId, TestConstants.ObjectId)] + public async Task ImdsV2UAMIHappyPathAsync( + UserAssignedIdentityId userAssignedIdentityId, + string userAssignedId) + { + using (var httpManager = new MockHttpManager()) + { + var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId); + miBuilder + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory) + .WithCsrFactory(_testCsrFactory); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // initial probe + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(idType: userAssignedIdentityId, userAssignedId: userAssignedId)); // do it again, since CsrMetadata from initial probe is not cached + httpManager.AddMockHandler(MockHelpers.MockCertificateRequestResponse(userAssignedIdentityId, userAssignedId)); + httpManager.AddManagedIdentityMockHandler( + $"{TestConstants.MtlsAuthenticationEndpoint}/{TestConstants.TenantId}{ImdsV2ManagedIdentitySource.AcquireEntraTokenPath}", + ManagedIdentityTests.Resource, MockHelpers.GetMsiSuccessfulResponse(), ManagedIdentitySource.ImdsV2);