diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs new file mode 100644 index 0000000000..4391fba4be --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CertificateRequestResponse.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_SYSTEM_TEXT_JSON + using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute; +#else + using Microsoft.Identity.Json; +#endif + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + /// + /// Represents the response for a Managed Identity CSR request. + /// + internal class CertificateRequestResponse + { + [JsonProperty("client_id")] + public string ClientId { get; set; } + + [JsonProperty("tenant_id")] + public string TenantId { get; set; } + + [JsonProperty("client_credential")] + public string ClientCredential { get; set; } + + [JsonProperty("regional_token_url")] + public string RegionalTokenUrl { get; set; } + + [JsonProperty("expires_in")] + public int ExpiresIn { get; set; } + + [JsonProperty("refresh_in")] + public int RefreshIn { get; set; } + + public CertificateRequestResponse() { } + + public static bool IsValid(CertificateRequestResponse certificateRequestResponse) + { + if (string.IsNullOrEmpty(certificateRequestResponse.ClientId) || + string.IsNullOrEmpty(certificateRequestResponse.TenantId) || + string.IsNullOrEmpty(certificateRequestResponse.ClientCredential) || + string.IsNullOrEmpty(certificateRequestResponse.RegionalTokenUrl) || + certificateRequestResponse.ExpiresIn <= 0 || + certificateRequestResponse.RefreshIn <= 0) + { + return false; + } + + return true; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs new file mode 100644 index 0000000000..fdc8584cd3 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/Csr.cs @@ -0,0 +1,477 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Security.Cryptography; +using System.Text; +using Microsoft.Identity.Client.Utils; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class Csr + { + public string Pem { get; } + + public Csr(string pem) + { + Pem = pem ?? throw new ArgumentNullException(nameof(pem)); + } + + /// + /// Generates a CSR for the given client, tenant, and CUID info. + /// + /// Managed Identity client_id. + /// AAD tenant_id. + /// CuidInfo object containing required VMID and optional VMSSID. + /// CsrRequest containing the PEM CSR. + public static Csr Generate(string clientId, string tenantId, CuidInfo cuid) + { + if (string.IsNullOrEmpty(clientId)) + throw new ArgumentException("clientId must not be null or empty.", nameof(clientId)); + if (string.IsNullOrEmpty(tenantId)) + throw new ArgumentException("tenantId must not be null or empty.", nameof(tenantId)); + if (cuid == null) + throw new ArgumentNullException(nameof(cuid)); + if (string.IsNullOrEmpty(cuid.Vmid)) + throw new ArgumentException("cuid.Vmid must not be null or empty.", nameof(cuid.Vmid)); + + string pemCsr = GeneratePkcs10Csr(clientId, tenantId, cuid); + return new Csr(pemCsr); + } + + /// + /// Generates a PKCS#10 Certificate Signing Request in PEM format. + /// + private static string GeneratePkcs10Csr(string clientId, string tenantId, CuidInfo cuid) + { + // Generate RSA key pair (2048-bit) + RSA rsa = CreateRsaKeyPair(); + + try + { + // Build the CSR components + byte[] certificationRequestInfo = BuildCertificationRequestInfo(clientId, tenantId, cuid, rsa); + byte[] signatureAlgorithm = BuildSignatureAlgorithmIdentifier(); + byte[] signature = SignCertificationRequestInfo(certificationRequestInfo, rsa); + + // Combine into final CSR structure + byte[] csrBytes = BuildFinalCsr(certificationRequestInfo, signatureAlgorithm, signature); + + // Convert to PEM format + return ConvertToPem(csrBytes); + } + finally + { + rsa?.Dispose(); + } + } + + /// + /// Creates a 2048-bit RSA key pair that supports PSS padding across all target frameworks. + /// + /// + /// On .NET Framework 4.6.2/4.7.2 (Windows-only), explicitly uses RSACng (also Windows-only) + /// to ensure PSS padding support, as RSA.Create() may return RSACryptoServiceProvider + /// which doesn't support PSS. + /// On .NET Standard 2.0 and .NET 8.0+ (cross-platform), uses RSA.Create() which returns + /// modern implementations that support PSS: RSACng on Windows, OpenSSL-based on Linux/macOS. + /// + /// An RSA instance configured for 2048-bit keys with PSS padding capability. + private static RSA CreateRsaKeyPair() + { + // TODO: use the strongest key on the machine i.e. a TPM key + RSA rsa = null; + +#if NET462 || NET472 + // .NET Framework runs only on Windows, so RSACng (Windows-only) is always available + rsa = new RSACng(); +#else + // Cross-platform .NET - RSA.Create() returns appropriate PSS-capable implementation + rsa = RSA.Create(); +#endif + rsa.KeySize = 2048; + return rsa; + } + + /// + /// Builds the CertificationRequestInfo structure containing subject, public key, and attributes. + /// + private static byte[] BuildCertificationRequestInfo(string clientId, string tenantId, CuidInfo cuid, RSA rsa) + { + var components = new List(); + + // Version (INTEGER 0) + components.Add(EncodeAsn1Integer(new byte[] { 0x00 })); + + // Subject: CN=, DC= + components.Add(BuildSubjectName(clientId, tenantId)); + + // SubjectPublicKeyInfo + components.Add(BuildSubjectPublicKeyInfo(rsa)); + + // Attributes (including CUID) + components.Add(BuildAttributes(cuid)); + + return EncodeAsn1Sequence(components.ToArray()); + } + + /// + /// Builds the X.500 Distinguished Name for the subject. + /// + private static byte[] BuildSubjectName(string clientId, string tenantId) + { + var rdnSequence = new List(); + + // CN= + byte[] cnOid = EncodeAsn1ObjectIdentifier(new int[] { 2, 5, 4, 3 }); // commonName OID + byte[] cnValue = EncodeAsn1Utf8String(clientId); + byte[] cnAttributeValue = EncodeAsn1Sequence(new[] { cnOid, cnValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { cnAttributeValue })); + + // DC= + byte[] dcOid = EncodeAsn1ObjectIdentifier(new int[] { 0, 9, 2342, 19200300, 100, 1, 25 }); // domainComponent OID + byte[] dcValue = EncodeAsn1Utf8String(tenantId); + byte[] dcAttributeValue = EncodeAsn1Sequence(new[] { dcOid, dcValue }); + rdnSequence.Add(EncodeAsn1Set(new[] { dcAttributeValue })); + + return EncodeAsn1Sequence(rdnSequence.ToArray()); + } + + /// + /// Builds the SubjectPublicKeyInfo structure containing the RSA public key. + /// + private static byte[] BuildSubjectPublicKeyInfo(RSA rsa) + { + RSAParameters rsaParams = rsa.ExportParameters(false); + + // RSA Public Key structure + byte[] modulus = EncodeAsn1Integer(rsaParams.Modulus); + byte[] exponent = EncodeAsn1Integer(rsaParams.Exponent); + byte[] rsaPublicKey = EncodeAsn1Sequence(new[] { modulus, exponent }); + + // Algorithm identifier for RSA encryption + byte[] rsaOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 1 }); // RSA encryption OID + byte[] nullParam = EncodeAsn1Null(); + byte[] algorithmIdentifier = EncodeAsn1Sequence(new[] { rsaOid, nullParam }); + + // SubjectPublicKeyInfo + byte[] publicKeyBitString = EncodeAsn1BitString(rsaPublicKey); + return EncodeAsn1Sequence(new[] { algorithmIdentifier, publicKeyBitString }); + } + + /// + /// Builds the attributes section including the CUID extension. + /// + private static byte[] BuildAttributes(CuidInfo cuid) + { + var attributes = new List(); + + // CUID attribute (OID 1.2.840.113549.1.9.7) + // Serialize CuidInfo as JSON object string using existing JSON serialization + byte[] cuidOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 9, 7 }); + string cuidValue = JsonHelper.SerializeToJson(cuid); + byte[] cuidData = EncodeAsn1PrintableString(cuidValue); + byte[] cuidAttributeValue = EncodeAsn1Set(new[] { cuidData }); + byte[] cuidAttribute = EncodeAsn1Sequence(new[] { cuidOid, cuidAttributeValue }); + attributes.Add(cuidAttribute); + + return EncodeAsn1ContextSpecific(0, EncodeAsn1SequenceRaw(attributes.ToArray())); + } + + /// + /// Builds the signature algorithm identifier for RSASSA-PSS with SHA256. + /// + private static byte[] BuildSignatureAlgorithmIdentifier() + { + byte[] rsassaPssOid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 10 }); // RSASSA-PSS OID + byte[] pssParams = BuildPssParameters(); + return EncodeAsn1Sequence(new[] { rsassaPssOid, pssParams }); + } + + /// + /// Builds the RSASSA-PSS parameters for SHA256 with MGF1. + /// + private static byte[] BuildPssParameters() + { + var parameters = new List(); + + // hashAlgorithm [0] AlgorithmIdentifier DEFAULT sha1 + // We explicitly specify SHA256 since default is SHA1 + byte[] sha256Oid = EncodeAsn1ObjectIdentifier(new int[] { 2, 16, 840, 1, 101, 3, 4, 2, 1 }); // SHA256 OID + byte[] sha256Null = EncodeAsn1Null(); + byte[] hashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); + byte[] hashAlgorithmParam = EncodeAsn1ContextSpecific(0, hashAlgorithm); + parameters.Add(hashAlgorithmParam); + + // maskGenAlgorithm [1] AlgorithmIdentifier DEFAULT mgf1SHA1 + // We explicitly specify MGF1 with SHA256 + byte[] mgf1Oid = EncodeAsn1ObjectIdentifier(new int[] { 1, 2, 840, 113549, 1, 1, 8 }); // MGF1 OID + byte[] mgf1HashAlgorithm = EncodeAsn1Sequence(new[] { sha256Oid, sha256Null }); // MGF1 uses SHA256 + byte[] maskGenAlgorithm = EncodeAsn1Sequence(new[] { mgf1Oid, mgf1HashAlgorithm }); + byte[] maskGenAlgorithmParam = EncodeAsn1ContextSpecific(1, maskGenAlgorithm); + parameters.Add(maskGenAlgorithmParam); + + // saltLength [2] INTEGER DEFAULT 20 + // We explicitly specify 32 for SHA256 (hash length) + byte[] saltLength = EncodeAsn1Integer(32); + byte[] saltLengthParam = EncodeAsn1ContextSpecific(2, saltLength); + parameters.Add(saltLengthParam); + + // trailerField [3] INTEGER DEFAULT 1 + // Default value is 1 (0xBC), so we omit this parameter + + return EncodeAsn1Sequence(parameters.ToArray()); + } + + /// + /// Signs the CertificationRequestInfo with SHA256withRSA-PSS. + /// + private static byte[] SignCertificationRequestInfo(byte[] certificationRequestInfo, RSA rsa) + { + return rsa.SignData(certificationRequestInfo, HashAlgorithmName.SHA256, RSASignaturePadding.Pss); + } + + /// + /// Combines all components into the final CSR structure. + /// + private static byte[] BuildFinalCsr(byte[] certificationRequestInfo, byte[] signatureAlgorithm, byte[] signature) + { + byte[] signatureBitString = EncodeAsn1BitString(signature); + return EncodeAsn1Sequence(new[] { certificationRequestInfo, signatureAlgorithm, signatureBitString }); + } + + /// + /// Converts DER-encoded bytes to PEM format. + /// + private static string ConvertToPem(byte[] derBytes) + { + string base64 = Convert.ToBase64String(derBytes); + var sb = new StringBuilder(); + sb.AppendLine("-----BEGIN CERTIFICATE REQUEST-----"); + + // Split into 64-character lines + for (int i = 0; i < base64.Length; i += 64) + { + int length = Math.Min(64, base64.Length - i); + sb.AppendLine(base64.Substring(i, length)); + } + + sb.AppendLine("-----END CERTIFICATE REQUEST-----"); + return sb.ToString(); + } + + #region ASN.1 Encoding Helpers + + /// + /// Encodes an ASN.1 SEQUENCE. + /// + private static byte[] EncodeAsn1Sequence(byte[][] components) + { + return EncodeAsn1Tag(0x30, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 SEQUENCE without the outer tag (for raw concatenation). + /// + private static byte[] EncodeAsn1SequenceRaw(byte[][] components) + { + return ConcatenateByteArrays(components); + } + + /// + /// Encodes an ASN.1 SET. + /// + private static byte[] EncodeAsn1Set(byte[][] components) + { + return EncodeAsn1Tag(0x31, ConcatenateByteArrays(components)); + } + + /// + /// Encodes an ASN.1 INTEGER. + /// + private static byte[] EncodeAsn1Integer(byte[] value) + { + // Ensure positive integer (prepend 0x00 if high bit is set) + if (value != null && value.Length > 0 && (value[0] & 0x80) != 0) + { + byte[] paddedValue = new byte[value.Length + 1]; + paddedValue[0] = 0x00; + Array.Copy(value, 0, paddedValue, 1, value.Length); + value = paddedValue; + } + return EncodeAsn1Tag(0x02, value ?? new byte[0]); + } + + /// + /// Encodes an ASN.1 INTEGER from an integer value. + /// + private static byte[] EncodeAsn1Integer(int value) + { + if (value == 0) + return EncodeAsn1Tag(0x02, new byte[] { 0x00 }); + + var bytes = new List(); + int temp = value; + while (temp > 0) + { + bytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + return EncodeAsn1Integer(bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 BIT STRING. + /// + private static byte[] EncodeAsn1BitString(byte[] value) + { + byte[] bitStringValue = new byte[value.Length + 1]; + bitStringValue[0] = 0x00; // No unused bits + Array.Copy(value, 0, bitStringValue, 1, value.Length); + return EncodeAsn1Tag(0x03, bitStringValue); + } + + /// + /// Encodes an ASN.1 UTF8String. + /// + private static byte[] EncodeAsn1Utf8String(string value) + { + byte[] utf8Bytes = Encoding.UTF8.GetBytes(value); + return EncodeAsn1Tag(0x0C, utf8Bytes); + } + + /// + /// Encodes an ASN.1 PrintableString. + /// + private static byte[] EncodeAsn1PrintableString(string value) + { + byte[] asciiBytes = Encoding.ASCII.GetBytes(value); + return EncodeAsn1Tag(0x13, asciiBytes); + } + + /// + /// Encodes an ASN.1 NULL. + /// + private static byte[] EncodeAsn1Null() + { + return new byte[] { 0x05, 0x00 }; + } + + /// + /// Encodes an ASN.1 OBJECT IDENTIFIER. + /// + private static byte[] EncodeAsn1ObjectIdentifier(int[] oid) + { + if (oid == null || oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var bytes = new List(); + + // First two components are encoded as (first * 40 + second) + bytes.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + bytes.AddRange(EncodeOidComponent(oid[i])); + } + + return EncodeAsn1Tag(0x06, bytes.ToArray()); + } + + /// + /// Encodes an ASN.1 context-specific tag. + /// + private static byte[] EncodeAsn1ContextSpecific(int tagNumber, byte[] content) + { + byte tag = (byte)(0xA0 | tagNumber); // Context-specific, constructed + return EncodeAsn1Tag(tag, content); + } + + /// + /// Encodes an ASN.1 tag with length and content. + /// + private static byte[] EncodeAsn1Tag(byte tag, byte[] content) + { + byte[] lengthBytes = EncodeAsn1Length(content.Length); + byte[] result = new byte[1 + lengthBytes.Length + content.Length]; + result[0] = tag; + Array.Copy(lengthBytes, 0, result, 1, lengthBytes.Length); + Array.Copy(content, 0, result, 1 + lengthBytes.Length, content.Length); + return result; + } + + /// + /// Encodes ASN.1 length field. + /// + private static byte[] EncodeAsn1Length(int length) + { + if (length < 0x80) + { + return new byte[] { (byte)length }; + } + + var lengthBytes = new List(); + int temp = length; + while (temp > 0) + { + lengthBytes.Insert(0, (byte)(temp & 0xFF)); + temp >>= 8; + } + + byte[] result = new byte[lengthBytes.Count + 1]; + result[0] = (byte)(0x80 | lengthBytes.Count); + lengthBytes.CopyTo(result, 1); + return result; + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + + /// + /// Concatenates multiple byte arrays. + /// + private static byte[] ConcatenateByteArrays(byte[][] arrays) + { + int totalLength = 0; + foreach (byte[] array in arrays) + { + totalLength += array.Length; + } + + byte[] result = new byte[totalLength]; + int offset = 0; + foreach (byte[] array in arrays) + { + Array.Copy(array, 0, result, offset, array.Length); + offset += array.Length; + } + + return result; + } + + #endregion + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs index 04a9e06baf..a831d02c7a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CsrMetadata.cs @@ -57,13 +57,12 @@ public CsrMetadata() { } /// Validates a JSON decoded CsrMetadata instance. /// /// The CsrMetadata object. - /// false if any field is null. + /// false if any required field is null. Note: Vmid is required, Vmssid is optional. public static bool ValidateCsrMetadata(CsrMetadata csrMetadata) { if (csrMetadata == null || csrMetadata.Cuid == null || string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) || - string.IsNullOrEmpty(csrMetadata.Cuid.Vmssid) || string.IsNullOrEmpty(csrMetadata.ClientId) || string.IsNullOrEmpty(csrMetadata.TenantId) || string.IsNullOrEmpty(csrMetadata.AttestationEndpoint)) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs index 9db03cc298..4d5354dcc6 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsV2ManagedIdentitySource.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Net; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; @@ -16,6 +17,7 @@ namespace Microsoft.Identity.Client.ManagedIdentity internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity { private const string CsrMetadataPath = "/metadata/identity/getPlatformMetadata"; + private const string CertificateRequestPath = "/metadata/identity/issuecredential"; public static async Task GetCsrMetadataAsync( RequestContext requestContext, @@ -29,7 +31,7 @@ public static async Task GetCsrMetadataAsync( requestContext.Logger); if (userAssignedIdQueryParam != null) { - queryParams += $"{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; + queryParams += $"&{userAssignedIdQueryParam.Value.Key}={userAssignedIdQueryParam.Value.Value}"; } var headers = new Dictionary @@ -41,7 +43,6 @@ public static async Task GetCsrMetadataAsync( IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.CsrMetadataProbe); - // CSR metadata GET request HttpResponse response = null; try @@ -50,7 +51,7 @@ public static async Task GetCsrMetadataAsync( ImdsManagedIdentitySource.GetValidatedEndpoint(requestContext.Logger, CsrMetadataPath, queryParams), headers, body: null, - method: System.Net.Http.HttpMethod.Get, + method: HttpMethod.Get, logger: requestContext.Logger, doNotThrow: false, mtlsCertificate: null, @@ -90,7 +91,7 @@ public static async Task GetCsrMetadataAsync( } } - if (!probeMode && !ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) + if (!ValidateCsrMetadataResponse(response, requestContext.Logger, probeMode)) { return null; } @@ -194,8 +195,90 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) internal ImdsV2ManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.ImdsV2) { } + private async Task ExecuteCertificateRequestAsync( + CuidInfo Cuid, + string pem) + { + var queryParams = $"cid={JsonHelper.SerializeToJson(Cuid)}"; + if (_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId != null) + { + queryParams += $"&uaid{_requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId}"; + } + queryParams += $"&api-version={ImdsManagedIdentitySource.ImdsApiVersion}"; + + var headers = new Dictionary + { + { "Metadata", "true" }, + { "x-ms-client-request-id", _requestContext.CorrelationId.ToString() } + }; + + var payload = new + { + pem = pem + }; + var body = JsonHelper.SerializeToJson(payload); + + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.Imds); + + HttpResponse response = null; + + try + { + response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( + ImdsManagedIdentitySource.GetValidatedEndpoint(_requestContext.Logger, CertificateRequestPath, queryParams), + headers, + body: new StringContent(body, System.Text.Encoding.UTF8, "application/json"), + method: HttpMethod.Post, + logger: _requestContext.Logger, + doNotThrow: false, + mtlsCertificate: null, + validateServerCertificate: null, + cancellationToken: _requestContext.UserCancellationToken, + retryPolicy: retryPolicy) + .ConfigureAwait(false); + } + catch (Exception ex) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed.", + ex, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + if (response.StatusCode != HttpStatusCode.OK) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed due to HTTP error. Status code: {response.StatusCode} Body: {response.Body}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + var certificateRequestResponse = JsonHelper.DeserializeFromJson(response.Body); + if (!CertificateRequestResponse.IsValid(certificateRequestResponse)) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + $"[ImdsV2] ImdsV2ManagedIdentitySource.ExecuteCertificateRequestAsync failed because the certificate request response is malformed. Status code: {response.StatusCode}", + null, + ManagedIdentitySource.ImdsV2, + (int)response.StatusCode); + } + + return certificateRequestResponse; + } + protected override ManagedIdentityRequest CreateRequest(string resource) { + var csrMetadata = GetCsrMetadataAsync(_requestContext, false).GetAwaiter().GetResult(); + var csr = Csr.Generate(csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.Cuid); + + var certificateRequestResponse = ExecuteCertificateRequestAsync(csrMetadata.Cuid, csr.Pem).GetAwaiter().GetResult(); + throw new NotImplementedException(); } } diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 3d89cc1bbe..d4a63354c0 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -154,6 +154,8 @@ public static HashSet s_scope public const string IdentityProvider = "my-idp"; public const string Name = "First Last"; public const string MiResourceId = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + public const string Vmid = "test-vm-id"; + public const string Vmssid = "test-vmss-id"; public const string Claims = @"{""userinfo"":{""given_name"":{""essential"":true},""nickname"":null,""email"":{""essential"":true},""email_verified"":{""essential"":true},""picture"":null,""http://example.info/claims/groups"":null},""id_token"":{""auth_time"":{""essential"":true},""acr"":{""values"":[""urn:mace:incommon:iap:silver""]}}}"; public static readonly string[] ClientCapabilities = new[] { "cp1", "cp2" }; diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs new file mode 100644 index 0000000000..671700c100 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CsrValidator.cs @@ -0,0 +1,431 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.Utils; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + /// + /// Test helper to expose CsrValidator methods for testing malformed PEM. + /// + internal static class TestCsrValidator + { + public static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + } + + /// + /// Helper class for validating Certificate Signing Request (CSR) content and structure. + /// + internal static class CsrValidator + { + /// + /// Validates the content of a CSR PEM string against expected values. + /// + public static void ValidateCsrContent(string pemCsr, string expectedClientId, string expectedTenantId, CuidInfo expectedCuid) + { + // Parse the CSR from PEM format + var csrData = ParseCsrFromPem(pemCsr); + + // Parse the PKCS#10 structure + var csrInfo = ParsePkcs10Structure(csrData); + + // Validate subject name + ValidateSubjectName(csrInfo.Subject, expectedClientId, expectedTenantId); + + // Validate public key + ValidatePublicKey(csrInfo.PublicKey); + + // Validate CUID attribute + ValidateCuidAttribute(csrInfo.Attributes, expectedCuid); + + // Validate signature algorithm + ValidateSignatureAlgorithm(csrInfo.SignatureAlgorithm); + } + + /// + /// Parses a PEM-formatted CSR and returns the DER bytes. + /// + private static byte[] ParseCsrFromPem(string pemCsr) + { + if (string.IsNullOrWhiteSpace(pemCsr)) + throw new ArgumentException("PEM CSR cannot be null or empty"); + + const string beginMarker = "-----BEGIN CERTIFICATE REQUEST-----"; + const string endMarker = "-----END CERTIFICATE REQUEST-----"; + + if (!pemCsr.Contains(beginMarker) || !pemCsr.Contains(endMarker)) + throw new ArgumentException("Invalid PEM format - missing CSR headers"); + + int beginIndex = pemCsr.IndexOf(beginMarker) + beginMarker.Length; + int endIndex = pemCsr.IndexOf(endMarker); + + if (beginIndex >= endIndex) + throw new ArgumentException("Invalid PEM format - malformed headers"); + + string base64Content = pemCsr.Substring(beginIndex, endIndex - beginIndex) + .Replace("\r", "").Replace("\n", "").Replace(" ", ""); + + try + { + return Convert.FromBase64String(base64Content); + } + catch (FormatException) + { + throw new FormatException("Invalid Base64 content in PEM CSR"); + } + } + + /// + /// Represents parsed PKCS#10 CSR information. + /// + private class CsrInfo + { + public byte[] Subject { get; set; } + public byte[] PublicKey { get; set; } + public byte[] Attributes { get; set; } + public byte[] SignatureAlgorithm { get; set; } + } + + /// + /// Parses the PKCS#10 ASN.1 structure and extracts key components. + /// + private static CsrInfo ParsePkcs10Structure(byte[] derBytes) + { + int offset = 0; + + // Parse outer SEQUENCE (CertificationRequest) + var outerSequence = ParseAsn1Tag(derBytes, ref offset, 0x30); + + // Reset offset to parse the CertificationRequestInfo within the outer sequence + int infoOffset = 0; + var certRequestInfo = ParseAsn1Tag(outerSequence, ref infoOffset, 0x30); + + // Parse version (should be 0) + int versionOffset = 0; + var version = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x02); + if (version.Length != 1 || version[0] != 0x00) + throw new ArgumentException("Invalid CSR version"); + + // Parse subject + var subject = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse SubjectPublicKeyInfo + var publicKey = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0x30); + + // Parse attributes (context-specific [0]) + var attributes = ParseAsn1Tag(certRequestInfo, ref versionOffset, 0xA0); + + return new CsrInfo + { + Subject = subject, + PublicKey = publicKey, + Attributes = attributes, + SignatureAlgorithm = new byte[0] // Simplified for this test + }; + } + + /// + /// Parses an ASN.1 tag and returns its content. + /// + private static byte[] ParseAsn1Tag(byte[] data, ref int offset, byte expectedTag) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data"); + + // Check tag (if expectedTag is -1, accept any tag) + if (expectedTag != 255 && data[offset] != expectedTag) + throw new ArgumentException($"Expected tag 0x{expectedTag:X2}, got 0x{data[offset]:X2}"); + + offset++; + + // Parse length + int length = ParseAsn1Length(data, ref offset); + + // Extract content + if (offset + length > data.Length) + throw new ArgumentException("Invalid ASN.1 length"); + + byte[] content = new byte[length]; + Array.Copy(data, offset, content, 0, length); + offset += length; + + return content; + } + + /// + /// Parses ASN.1 length encoding. + /// + private static int ParseAsn1Length(byte[] data, ref int offset) + { + if (offset >= data.Length) + throw new ArgumentException("Unexpected end of data in length"); + + byte firstByte = data[offset++]; + + // Short form + if ((firstByte & 0x80) == 0) + return firstByte; + + // Long form + int lengthBytes = firstByte & 0x7F; + if (lengthBytes == 0) + throw new ArgumentException("Indefinite length not supported"); + + if (offset + lengthBytes > data.Length) + throw new ArgumentException("Invalid length encoding"); + + int length = 0; + for (int i = 0; i < lengthBytes; i++) + { + length = (length << 8) | data[offset++]; + } + + return length; + } + + /// + /// Validates the subject name contains the expected client ID and tenant ID. + /// + private static void ValidateSubjectName(byte[] subjectBytes, string expectedClientId, string expectedTenantId) + { + // Subject is already a SEQUENCE of RDNs + int offset = 0; + bool foundClientId = false; + bool foundTenantId = false; + + // Parse each RDN (Relative Distinguished Name) directly from subjectBytes + while (offset < subjectBytes.Length) + { + var rdnSet = ParseAsn1Tag(subjectBytes, ref offset, 0x31); // SET + + int rdnOffset = 0; + var rdnSequence = ParseAsn1Tag(rdnSet, ref rdnOffset, 0x30); // SEQUENCE + + // Parse OID and value + int attrOffset = 0; + var oid = ParseAsn1Tag(rdnSequence, ref attrOffset, 0x06); // OID + var value = ParseAsn1Tag(rdnSequence, ref attrOffset, 255); // Any string type + + string stringValue = System.Text.Encoding.UTF8.GetString(value); + + // Check for CN (commonName) OID: 2.5.4.3 + if (IsOid(oid, new int[] { 2, 5, 4, 3 })) + { + Assert.AreEqual(expectedClientId, stringValue, "Client ID in subject CN does not match"); + foundClientId = true; + } + // Check for DC (domainComponent) OID: 0.9.2342.19200300.100.1.25 + else if (IsOid(oid, new int[] { 0, 9, 2342, 19200300, 100, 1, 25 })) + { + Assert.AreEqual(expectedTenantId, stringValue, "Tenant ID in subject DC does not match"); + foundTenantId = true; + } + } + + Assert.IsTrue(foundClientId, "Client ID (CN) not found in subject"); + Assert.IsTrue(foundTenantId, "Tenant ID (DC) not found in subject"); + } + + /// + /// Validates the public key is a valid RSA key. + /// + private static void ValidatePublicKey(byte[] publicKeyBytes) + { + // publicKeyBytes is already the SubjectPublicKeyInfo SEQUENCE content + int offset = 0; + + // Parse algorithm identifier + var algorithmId = ParseAsn1Tag(publicKeyBytes, ref offset, 0x30); + + // Parse public key bit string + var publicKeyBitString = ParseAsn1Tag(publicKeyBytes, ref offset, 0x03); + + // Validate algorithm is RSA (1.2.840.113549.1.1.1) + int algOffset = 0; + var algorithmOid = ParseAsn1Tag(algorithmId, ref algOffset, 0x06); + Assert.IsTrue(IsOid(algorithmOid, new int[] { 1, 2, 840, 113549, 1, 1, 1 }), + "Public key algorithm is not RSA"); + + // Skip the unused bits byte in bit string + if (publicKeyBitString.Length < 2 || publicKeyBitString[0] != 0x00) + throw new ArgumentException("Invalid public key bit string"); + + // Parse RSA public key (skip unused bits byte) + byte[] rsaKeyBytes = new byte[publicKeyBitString.Length - 1]; + Array.Copy(publicKeyBitString, 1, rsaKeyBytes, 0, rsaKeyBytes.Length); + + int rsaOffset = 0; + var rsaSequence = ParseAsn1Tag(rsaKeyBytes, ref rsaOffset, 0x30); + + rsaOffset = 0; + var modulus = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + var exponent = ParseAsn1Tag(rsaSequence, ref rsaOffset, 0x02); + + // Validate key size (should be 2048 bits = 256 bytes, plus potential leading zero) + Assert.IsTrue(modulus.Length >= 256 && modulus.Length <= 257, + $"RSA modulus should be 2048 bits, got {modulus.Length * 8} bits"); + + // Validate exponent (commonly 65537 = 0x010001) + Assert.IsTrue(exponent.Length >= 1 && exponent.Length <= 4, "RSA exponent has invalid length"); + } + + /// + /// Validates the CUID attribute contains the expected VM and VMSS IDs as JSON. + /// Note: Vmid is required, Vmssid is optional and will be omitted if null/empty. + /// + private static void ValidateCuidAttribute(byte[] attributesBytes, CuidInfo expectedCuid) + { + // Attributes is a SET of attributes + // We expect one attribute with challengePassword OID (1.2.840.113549.1.9.7) + + int offset = 0; + bool foundCuid = false; + + // Parse each attribute in the SET + while (offset < attributesBytes.Length) + { + var attributeSequence = ParseAsn1Tag(attributesBytes, ref offset, 0x30); + + int attrOffset = 0; + var oid = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x06); + var valueSet = ParseAsn1Tag(attributeSequence, ref attrOffset, 0x31); // SET of values + + // Check for challengePassword OID: 1.2.840.113549.1.9.7 + if (IsOid(oid, new int[] { 1, 2, 840, 113549, 1, 9, 7 })) + { + // Parse the value from the SET (should be one value) + int valueOffset = 0; + var value = ParseAsn1Tag(valueSet, ref valueOffset, 255); // Any string type + + string cuidValue = System.Text.Encoding.ASCII.GetString(value); + + // Build expected CUID value as JSON + string expectedCuidValue = BuildExpectedCuidJson(expectedCuid); + + Assert.AreEqual(expectedCuidValue, cuidValue, "CUID attribute JSON value does not match expected"); + foundCuid = true; + break; + } + } + + Assert.IsTrue(foundCuid, "CUID (challengePassword) attribute not found"); + } + + /// + /// Builds the expected CUID JSON string for validation using JsonHelper. + /// + private static string BuildExpectedCuidJson(CuidInfo expectedCuid) + { + return JsonHelper.SerializeToJson(expectedCuid); + } + + /// + /// Validates the signature algorithm is SHA256withRSA. + /// + private static void ValidateSignatureAlgorithm(byte[] signatureAlgBytes) + { + // For this test, we'll just verify that signature algorithm exists + // Full validation would require parsing the outer CSR structure + // which is more complex for this unit test scenario + Assert.IsNotNull(signatureAlgBytes, "Signature algorithm should be present"); + } + + /// + /// Checks if the given OID bytes match the expected OID components. + /// + private static bool IsOid(byte[] oidBytes, int[] expectedOid) + { + if (expectedOid.Length < 2) + return false; + + var expectedBytes = EncodeOid(expectedOid); + + if (oidBytes.Length != expectedBytes.Length) + return false; + + for (int i = 0; i < oidBytes.Length; i++) + { + if (oidBytes[i] != expectedBytes[i]) + return false; + } + + return true; + } + + /// + /// Encodes an OID from integer components to bytes (simplified version). + /// + private static byte[] EncodeOid(int[] oid) + { + if (oid.Length < 2) + throw new ArgumentException("OID must have at least 2 components"); + + var result = new System.Collections.Generic.List(); + + // First two components are encoded as (first * 40 + second) + result.AddRange(EncodeOidComponent(oid[0] * 40 + oid[1])); + + // Remaining components + for (int i = 2; i < oid.Length; i++) + { + result.AddRange(EncodeOidComponent(oid[i])); + } + + return result.ToArray(); + } + + /// + /// Encodes a single OID component using variable-length encoding. + /// + private static byte[] EncodeOidComponent(int value) + { + if (value == 0) + return new byte[] { 0x00 }; + + var bytes = new System.Collections.Generic.List(); + int temp = value; + + bytes.Insert(0, (byte)(temp & 0x7F)); + temp >>= 7; + + while (temp > 0) + { + bytes.Insert(0, (byte)((temp & 0x7F) | 0x80)); + temp >>= 7; + } + + return bytes.ToArray(); + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index e1aea27aa4..6851b425e3 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client; @@ -130,5 +131,100 @@ public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIs Assert.AreEqual(ManagedIdentitySource.DefaultToImds, miSource); } } + + [TestMethod] + public void TestCsrGeneration() + { + var cuid = new CuidInfo + { + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid + }; + + // Generate CSR + var csr = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + + // Validate the CSR contents using the helper + CsrValidator.ValidateCsrContent(csr.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + } + + [DataTestMethod] + [DataRow(null, TestConstants.TenantId)] + [DataRow("", TestConstants.TenantId)] + [DataRow(TestConstants.ClientId, null)] + [DataRow(TestConstants.ClientId, "")] + public void TestCsrGeneration_InvalidParameters(string clientId, string tenantId) + { + var cuid = new CuidInfo + { + Vmid = TestConstants.Vmid, + Vmssid = TestConstants.Vmssid + }; + + Assert.ThrowsException(() => + Csr.Generate(clientId, tenantId, cuid)); + } + + [TestMethod] + public void TestCsrGeneration_NullCuid() + { + // Test with null CUID + Assert.ThrowsException(() => + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, null)); + } + + [DataTestMethod] + [DataRow(null, TestConstants.Vmssid)] + [DataRow("", TestConstants.Vmssid)] + public void TestCsrGeneration_InvalidVmid(string vmid, string vmssid) + { + var cuid = new CuidInfo + { + Vmid = vmid, + Vmssid = vmssid + }; + + // Should throw ArgumentException since Vmid is required + Assert.ThrowsException(() => + Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid)); + } + + [DataTestMethod] + [DataRow(TestConstants.Vmid, null)] + [DataRow(TestConstants.Vmid, "")] + public void TestCsrGeneration_OptionalVmssid(string vmid, string vmssid) + { + var cuid = new CuidInfo + { + Vmid = vmid, + Vmssid = vmssid + }; + + // Should succeed since Vmssid is optional (Vmid is provided and valid) + var csrRequest = Csr.Generate(TestConstants.ClientId, TestConstants.TenantId, cuid); + Assert.IsNotNull(csrRequest); + Assert.IsFalse(string.IsNullOrWhiteSpace(csrRequest.Pem)); + + // Validate the CSR contents - this should handle null/empty VMSSID gracefully + CsrValidator.ValidateCsrContent(csrRequest.Pem, TestConstants.ClientId, TestConstants.TenantId, cuid); + } + + [TestMethod] + public void TestCsrGeneration_MalformedPem_FormatException() + { + string malformedPem = "-----BEGIN CERTIFICATE REQUEST-----\nInvalid@#$%Base64Content!\n-----END CERTIFICATE REQUEST-----"; + Assert.ThrowsException(() => + TestCsrValidator.ParseCsrFromPem(malformedPem)); + } + + [DataTestMethod] + [DataRow("-----BEGIN CERTIFICATE-----\nTUlJQzNqQ0NBY1lDQVFBd1pURT0K\n-----END CERTIFICATE REQUEST-----")] + [DataRow("")] + [DataRow(null)] + public void TestCsrGeneration_MalformedPem_ArgumentException(string malformedPem) + { + Assert.ThrowsException(() => + TestCsrValidator.ParseCsrFromPem(malformedPem)); + } } }