Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Encodes/decodes the persisted X.509 FriendlyName for MSAL mTLS certs.
/// Format: "MSAL|alias=cacheKey|ep=endpointBase"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give more examples here? This format is not clear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is alias, cachekey, ep and endpoint base?

/// Open the cert store and look at FriendlyName to see examples.
/// Wish we could paste a screenshot here... Maybe I can show it in code walkthroughs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

/// </summary>
internal static class FriendlyNameCodec
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming: MsiCertificateFriendlyNameEncoder ?

{
public const string Prefix = "MSAL|";
public const string TagAlias = "alias";
public const string TagEp = "ep";

/// <summary>
/// Encodes alias and endpointBase into friendly name.
/// </summary>
/// <param name="alias"></param>
/// <param name="endpointBase"></param>
/// <param name="friendlyName"></param>
/// <returns></returns>
public static bool TryEncode(string alias, string endpointBase, out string friendlyName)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Avoid using these TryXYZ methods for cases where exceptions should be thrown.

{
friendlyName = null;

if (string.IsNullOrWhiteSpace(alias) || string.IsNullOrWhiteSpace(endpointBase))
return false;

alias = alias.Trim();
endpointBase = endpointBase.Trim();

// Forbid characters that would break our simple delimiter-based grammar.
if (ContainsIllegal(alias) || ContainsIllegal(endpointBase))
return false;

friendlyName = Prefix + TagAlias + "=" + alias + "|" + TagEp + "=" + endpointBase;
return true;
}

/// <summary>
/// Decodes friendly name into alias and endpointBase.
/// </summary>
/// <param name="friendlyName"></param>
/// <param name="alias"></param>
/// <param name="endpointBase"></param>
/// <returns></returns>
public static bool TryDecode(string friendlyName, out string alias, out string endpointBase)
{
alias = null;
endpointBase = null;

if (string.IsNullOrEmpty(friendlyName) ||
!friendlyName.StartsWith(Prefix, StringComparison.Ordinal))
{
return false;
}

// Example: MSAL|alias=<cacheKey>|ep=<endpointBase>
var payload = friendlyName.Substring(Prefix.Length);
var parts = payload.Split(new[] { '|' }, StringSplitOptions.RemoveEmptyEntries);

// Parse key-value pairs
foreach (var part in parts)
{
var kv = part.Split(new[] { '=' }, 2);
if (kv.Length != 2)
continue;

var k = kv[0].Trim();
var v = kv[1].Trim();

if (k.Equals(TagAlias, StringComparison.Ordinal))
{
alias = v; // minimal: last-wins
}
else if (k.Equals(TagEp, StringComparison.Ordinal))
{
endpointBase = v;
}
}

return !string.IsNullOrWhiteSpace(alias) && !string.IsNullOrWhiteSpace(endpointBase);
}

/// <summary>
/// Checks for illegal characters in alias/endpointBase.
/// Endpoint itself comes from IMDS and is well-formed, but we still validate.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not throw?

/// </summary>
/// <param name="value"></param>
/// <returns></returns>
private static bool ContainsIllegal(string value)
{
for (int i = 0; i < value.Length; i++)
{
char c = value[i];
if (c == '|' || c == '\r' || c == '\n' || c == '\0')
return true;
}
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,6 @@ private async Task<string> GetAttestationJwtAsync(
return response.AttestationToken;
}

// ...unchanged usings and class header...

/// <summary>
/// Read-through cache: try cache; if missing, run async factory once (per key),
/// store the result, and return it. Thread-safe for the given cacheKey.
Expand All @@ -457,21 +455,13 @@ private static async Task<Tuple<X509Certificate2, string, string>> GetOrCreateMt
if (factory is null)
throw new ArgumentNullException(nameof(factory));

X509Certificate2 cachedCertificate;
string cachedEndpointBase;
string cachedClientId;

// 1) Only lookup by cacheKey
// 1) In-memory cache first
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 1) 2) 3) logic should be encapsualted in a different object, some sort of Cache abstraction.

if (s_mtlsCertificateCache.TryGet(cacheKey, out var cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;

return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
return Tuple.Create(cached.Certificate, cached.Endpoint, cached.ClientId);
}

// 2) Gate per cacheKey
// 2) Per-key gate
var gate = s_perKeyGates.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1));
await gate.WaitAsync(cancellationToken).ConfigureAwait(false);

Expand All @@ -480,18 +470,37 @@ private static async Task<Tuple<X509Certificate2, string, string>> GetOrCreateMt
// Re-check after acquiring the gate
if (s_mtlsCertificateCache.TryGet(cacheKey, out cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;
return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
return Tuple.Create(cached.Certificate, cached.Endpoint, cached.ClientId);
}

// 3) Persistent store (best-effort).
if (PersistentCertificateStore.TryFind(cacheKey, out var persisted, logger))
{
if (persisted.Certificate.HasPrivateKey)
{
var v = new CertificateCacheValue(persisted.Certificate, persisted.Endpoint, persisted.ClientId);
s_mtlsCertificateCache.Set(cacheKey, in v, logger);
return Tuple.Create(v.Certificate, v.Endpoint, v.ClientId);
}
else
{
// Not usable for mTLS; dispose clone and mint a new one
persisted.Certificate.Dispose();
logger?.Verbose(() => "[PersistentCert] Skipping persisted cert without private key; minting new.");
}
}

// 3) Mint + cache under the provided cacheKey
// 4) Mint + back-fill caches
var created = await factory().ConfigureAwait(false);

s_mtlsCertificateCache.Set(cacheKey,
new CertificateCacheValue(created.Item1, created.Item2, created.Item3),
logger);
var createdValue = new CertificateCacheValue(created.Item1, created.Item2, created.Item3);
s_mtlsCertificateCache.Set(cacheKey, in createdValue, logger);

// 5) Best-effort persist for future runs (mutex & dedup inside)
PersistentCertificateStore.TryPersist(cacheKey, created.Item1, created.Item2, created.Item3, logger);

// 6) Keep store tidy
PersistentCertificateStore.TryPruneAliasOlderThan(cacheKey, created.Item1.NotAfter.ToUniversalTime(), logger);

return created;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Microsoft.Identity.Client.PlatformsCommon.Shared;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Cross-process lock based on a per-alias named mutex.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is an alias?

/// </summary>
internal static class InterprocessLock
{
public static bool TryWithAliasLock(
string alias,
TimeSpan timeout,
Action action,
Action<string> logVerbose = null)
{
var nameGlobal = GetMutexNameForAlias(alias, preferGlobal: true);
var nameLocal = GetMutexNameForAlias(alias, preferGlobal: false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain in comment why you need 2 mutexes


foreach (var name in new[] { nameGlobal, nameLocal })
{
try
{
using var m = new Mutex(false, name);
bool entered;
try
{
entered = m.WaitOne(timeout);
}
catch (AbandonedMutexException)
{
entered = true; // prior holder crashed
}

if (!entered)
{
logVerbose?.Invoke($"[PersistentCert] Skip persist (lock busy '{name}').");
return false;
}

try
{ action(); }
finally
{
try
{ m.ReleaseMutex(); }
catch { /* best-effort */ }
}

return true;
}
catch (UnauthorizedAccessException)
{
logVerbose?.Invoke($"[PersistentCert] No access to mutex scope '{name}', trying next.");
continue; // try Local if Global blocked
}
catch (Exception ex)
{
logVerbose?.Invoke($"[PersistentCert] Lock failure '{name}': {ex.Message}");
return false;
}
}

return false;
}

public static string GetMutexNameForAlias(string alias, bool preferGlobal = true)
{
string suffix = HashAlias(Canonicalize(alias));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to hash? How long are the aliases? Maybe you can avoid hashing here.

return (preferGlobal ? @"Global\" : @"Local\") + "MSAL_MI_P_" + suffix;
}

private static string Canonicalize(string alias) => (alias ?? string.Empty).Trim().ToUpperInvariant();

private static string HashAlias(string s)
{
try
{
var hex = new CommonCryptographyManager().CreateSha256HashHex(s);
// Truncate to 32 chars to fit mutex name length limits
return string.IsNullOrEmpty(hex) ? "0" : (hex.Length > 32 ? hex.Substring(0, 32) : hex);
}
catch
{
return "0";
}
}
}
}
Loading