Skip to content

Commit b415de2

Browse files
authored
Adding MSAL token caching support to ManagedIdentityCredential (Azure#30432)
* updating to use native broker impl * upgrade to latest msal * adding conf client caching to managed identity cred * disable manual test * removing local nuget feed * adding caching and refresh tests
1 parent a7813ed commit b415de2

File tree

4 files changed

+85
-3
lines changed

4 files changed

+85
-3
lines changed

sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Linq;
6+
using System.Net.Http.Headers;
57
using System.Threading;
68
using System.Threading.Tasks;
79
using Azure.Core;
10+
using Microsoft.Identity.Client;
11+
using Microsoft.Identity.Client.Extensibility;
812

913
namespace Azure.Identity
1014
{
@@ -14,6 +18,7 @@ internal class ManagedIdentityClient
1418
"ManagedIdentityCredential authentication unavailable. No Managed Identity endpoint found.";
1519

1620
private Lazy<ManagedIdentitySource> _identitySource;
21+
private MsalConfidentialClient _msal;
1722

1823
protected ManagedIdentityClient()
1924
{
@@ -40,18 +45,35 @@ public ManagedIdentityClient(ManagedIdentityClientOptions options)
4045
ClientId = options.ClientId;
4146
Pipeline = options.Pipeline;
4247
_identitySource = new Lazy<ManagedIdentitySource>(() => SelectManagedIdentitySource(options));
48+
_msal = new MsalConfidentialClient(Pipeline, "MANAGED-IDENTITY-RESOURCE-TENENT", ClientId ?? "SYSTEM-ASSIGNED-MANAGED-IDENTITY", AppTokenProviderImpl, options.Options);
4349
}
4450

4551
internal CredentialPipeline Pipeline { get; }
4652

4753
protected string ClientId { get; }
4854

49-
public virtual async ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context,
55+
public async ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
56+
{
57+
AuthenticationResult result = await _msal.AcquireTokenForClientAsync(context.Scopes, context.TenantId, async, cancellationToken).ConfigureAwait(false);
58+
59+
return new AccessToken(result.AccessToken, result.ExpiresOn);
60+
}
61+
62+
public virtual async ValueTask<AccessToken> AuthenticateCoreAsync(bool async, TokenRequestContext context,
5063
CancellationToken cancellationToken)
5164
{
5265
return await _identitySource.Value.AuthenticateAsync(async, context, cancellationToken).ConfigureAwait(false);
5366
}
5467

68+
private async Task<AppTokenProviderResult> AppTokenProviderImpl(AppTokenProviderParameters parameters)
69+
{
70+
TokenRequestContext requestContext = new TokenRequestContext(parameters.Scopes.ToArray(), claims: parameters.Claims);
71+
72+
AccessToken token = await AuthenticateCoreAsync(true, requestContext, parameters.CancellationToken).ConfigureAwait(false);
73+
74+
return new AppTokenProviderResult() { AccessToken = token.Token, ExpiresInSeconds = Math.Max(Convert.ToInt64((token.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds), 1) };
75+
}
76+
5577
private static ManagedIdentitySource SelectManagedIdentitySource(ManagedIdentityClientOptions options)
5678
{
5779
return

sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77
using System.Threading;
88
using System.Threading.Tasks;
99
using Microsoft.Identity.Client;
10+
using Microsoft.Identity.Client.Extensibility;
1011

1112
namespace Azure.Identity
1213
{
1314
internal class MsalConfidentialClient : MsalClientBase<IConfidentialClientApplication>
1415
{
16+
private const string s_instanceMetadata = "{\"tenant_discovery_endpoint\":\"https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration\",\"api-version\":\"1.1\",\"metadata\":[{\"preferred_network\":\"login.microsoftonline.com\",\"preferred_cache\":\"login.windows.net\",\"aliases\":[\"login.microsoftonline.com\",\"login.windows.net\",\"login.microsoft.com\",\"sts.windows.net\"]}]}";
1517
internal readonly string _clientSecret;
1618
internal readonly bool _includeX5CClaimHeader;
1719
internal readonly IX509Certificate2Provider _certificateProvider;
1820
private readonly Func<string> _assertionCallback;
1921
private readonly Func<CancellationToken, Task<string>> _asyncAssertionCallback;
22+
private readonly Func<AppTokenProviderParameters, Task<AppTokenProviderResult>> _appTokenProviderCallback;
2023

2124
internal string RedirectUrl { get; }
2225

@@ -52,15 +55,32 @@ public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, stri
5255
_asyncAssertionCallback = assertionCallback;
5356
}
5457

58+
public MsalConfidentialClient(CredentialPipeline pipeline, string tenantId, string clientId, Func<AppTokenProviderParameters, Task<AppTokenProviderResult>> appTokenProviderCallback, TokenCredentialOptions options)
59+
: base(pipeline, tenantId, clientId, options)
60+
{
61+
_appTokenProviderCallback = appTokenProviderCallback;
62+
}
63+
5564
internal string RegionalAuthority { get; } = EnvironmentVariables.AzureRegionalAuthorityName;
5665

5766
protected override async ValueTask<IConfidentialClientApplication> CreateClientAsync(bool async, CancellationToken cancellationToken)
5867
{
5968
ConfidentialClientApplicationBuilder confClientBuilder = ConfidentialClientApplicationBuilder.Create(ClientId)
60-
.WithAuthority(Pipeline.AuthorityHost.AbsoluteUri, TenantId)
6169
.WithHttpClientFactory(new HttpPipelineClientFactory(Pipeline.HttpPipeline))
6270
.WithLogging(LogMsal, enablePiiLogging: IsPiiLoggingEnabled);
6371

72+
//special case for using appTokenProviderCallback, authority validation and instance metadata discovery should be disabled since we're not calling the STS
73+
if (_appTokenProviderCallback != null)
74+
{
75+
confClientBuilder.WithAppTokenProvider(_appTokenProviderCallback)
76+
.WithAuthority(Pipeline.AuthorityHost.AbsoluteUri, TenantId, false)
77+
.WithInstanceDiscoveryMetadata(s_instanceMetadata);
78+
}
79+
else
80+
{
81+
confClientBuilder.WithAuthority(Pipeline.AuthorityHost.AbsoluteUri, TenantId);
82+
}
83+
6484
if (_clientSecret != null)
6585
{
6686
confClientBuilder.WithClientSecret(_clientSecret);

sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,46 @@ public ManagedIdentityCredentialTests(bool isAsync) : base(isAsync)
2929

3030
private const string ExpectedToken = "mock-msi-access-token";
3131

32+
[Test]
33+
public async Task VerifyTokenCaching()
34+
{
35+
int callCount = 0;
36+
37+
var mockClient = new MockManagedIdentityClient(CredentialPipeline.GetInstance(null))
38+
{
39+
TokenFactory = () => { callCount++; return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddHours(24)); }
40+
};
41+
42+
var cred = InstrumentClient(new ManagedIdentityCredential(mockClient));
43+
44+
for (int i = 0; i < 5; i++)
45+
{
46+
await cred.GetTokenAsync(new TokenRequestContext(MockScopes.Default));
47+
}
48+
49+
Assert.AreEqual(1, callCount);
50+
}
51+
52+
[Test]
53+
public async Task VerifyExpiringTokenRefresh()
54+
{
55+
int callCount = 0;
56+
57+
var mockClient = new MockManagedIdentityClient(CredentialPipeline.GetInstance(null))
58+
{
59+
TokenFactory = () => { callCount++; return new AccessToken(Guid.NewGuid().ToString(), DateTimeOffset.UtcNow.AddMinutes(2)); }
60+
};
61+
62+
var cred = InstrumentClient(new ManagedIdentityCredential(mockClient));
63+
64+
for (int i = 0; i < 5; i++)
65+
{
66+
await cred.GetTokenAsync(new TokenRequestContext(MockScopes.Default));
67+
}
68+
69+
Assert.AreEqual(5, callCount);
70+
}
71+
3272
[NonParallelizable]
3373
[Test]
3474
public async Task VerifyImdsRequestWithClientIdMockAsync()

sdk/identity/Azure.Identity/tests/Mock/MockManagedIdentityClient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public MockManagedIdentityClient(CredentialPipeline pipeline, string clientId)
2727

2828
public Func<AccessToken> TokenFactory { get; set; }
2929

30-
public override ValueTask<AccessToken> AuthenticateAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
30+
public override ValueTask<AccessToken> AuthenticateCoreAsync(bool async, TokenRequestContext context, CancellationToken cancellationToken)
3131
=> TokenFactory != null ? new ValueTask<AccessToken>(TokenFactory()) : base.AuthenticateAsync(async, context, cancellationToken);
3232
}
3333
}

0 commit comments

Comments
 (0)