Skip to content

Commit 20ed99e

Browse files
Merge pull request #931 from AzureAD/nebharg/passTokenSourceMI
Pass token source and update tests
2 parents 3699125 + 42eee6a commit 20ed99e

File tree

4 files changed

+121
-46
lines changed

4 files changed

+121
-46
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class AcquireTokenByManagedIdentitySupplier extends AuthenticationResultSupplier
1313

1414
private static final Logger LOG = LoggerFactory.getLogger(AcquireTokenByManagedIdentitySupplier.class);
1515

16-
private static final int TWO_HOURS = 2*3600;
16+
private static final int TWO_HOURS = 2 * 3600;
1717

1818
private ManagedIdentityParameters managedIdentityParameters;
1919

@@ -37,49 +37,66 @@ AuthenticationResult execute() throws Exception {
3737
clientApplication.serviceBundle()
3838
);
3939

40-
if (!managedIdentityParameters.forceRefresh) {
41-
LOG.debug("ForceRefresh set to false. Attempting cache lookup");
42-
43-
try {
44-
Set<String> scopes = new HashSet<>();
45-
scopes.add(this.managedIdentityParameters.resource);
46-
SilentParameters parameters = SilentParameters
47-
.builder(scopes)
48-
.tenant(managedIdentityParameters.tenant())
49-
.build();
50-
51-
RequestContext context = new RequestContext(
52-
this.clientApplication,
53-
PublicApi.ACQUIRE_TOKEN_SILENTLY,
54-
parameters);
55-
56-
SilentRequest silentRequest = new SilentRequest(
57-
parameters,
58-
this.clientApplication,
59-
context,
60-
null);
61-
62-
AcquireTokenSilentSupplier supplier = new AcquireTokenSilentSupplier(
63-
this.clientApplication,
64-
silentRequest);
65-
66-
return supplier.execute();
67-
} catch (MsalClientException ex) {
68-
if (ex.errorCode().equals(AuthenticationErrorCode.CACHE_MISS)) {
69-
LOG.debug(String.format("Cache lookup failed: %s", ex.getMessage()));
70-
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor);
71-
} else {
72-
LOG.error(String.format("Error occurred while cache lookup: %s", ex.getMessage()));
73-
throw ex;
74-
}
75-
}
40+
CacheRefreshReason cacheRefreshReason = CacheRefreshReason.NOT_APPLICABLE;
41+
42+
if (managedIdentityParameters.forceRefresh) {
43+
LOG.debug("ForceRefresh set to true. Skipping cache lookup and attempting to acquire new token");
44+
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor, CacheRefreshReason.FORCE_REFRESH);
7645
}
7746

78-
LOG.info("Skipped looking for an Access Token in the cache because forceRefresh or Claims were set. ");
79-
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor);
47+
48+
LOG.debug("ForceRefresh set to false. Attempting cache lookup");
49+
try {
50+
Set<String> scopes = new HashSet<>();
51+
scopes.add(this.managedIdentityParameters.resource);
52+
SilentParameters parameters = SilentParameters
53+
.builder(scopes)
54+
.tenant(managedIdentityParameters.tenant())
55+
.build();
56+
57+
RequestContext context = new RequestContext(
58+
this.clientApplication,
59+
PublicApi.ACQUIRE_TOKEN_SILENTLY,
60+
parameters);
61+
62+
SilentRequest silentRequest = new SilentRequest(
63+
parameters,
64+
this.clientApplication,
65+
context,
66+
null);
67+
68+
AcquireTokenSilentSupplier supplier = new AcquireTokenSilentSupplier(
69+
this.clientApplication,
70+
silentRequest);
71+
72+
AuthenticationResult result = supplier.execute();
73+
cacheRefreshReason = SilentRequestHelper.getCacheRefreshReasonIfApplicable(
74+
parameters,
75+
result,
76+
LOG);
77+
78+
// If the token does not need a refresh, return the cached token
79+
// Else refresh the token if it is either expired, proactively refreshable, or if the claims are passed.
80+
if (cacheRefreshReason == CacheRefreshReason.NOT_APPLICABLE) {
81+
LOG.debug("Returning token from cache");
82+
result.metadata().tokenSource(TokenSource.CACHE);
83+
return result;
84+
} else {
85+
LOG.debug(String.format("Refreshing access token. Cache refresh reason: %s", cacheRefreshReason));
86+
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor, cacheRefreshReason);
87+
}
88+
} catch (MsalClientException ex) {
89+
if (ex.errorCode().equals(AuthenticationErrorCode.CACHE_MISS)) {
90+
LOG.debug(String.format("Cache lookup failed: %s", ex.getMessage()));
91+
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor, cacheRefreshReason);
92+
} else {
93+
LOG.error(String.format("Error occurred while cache lookup: %s", ex.getMessage()));
94+
throw ex;
95+
}
96+
}
8097
}
8198

82-
private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecutor tokenRequestExecutor) {
99+
private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecutor tokenRequestExecutor, CacheRefreshReason cacheRefreshReason) throws Exception {
83100

84101
ManagedIdentityClient managedIdentityClient = new ManagedIdentityClient(msalRequest, tokenRequestExecutor.getServiceBundle());
85102

@@ -91,13 +108,17 @@ private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecu
91108

92109
AuthenticationResult authenticationResult = createFromManagedIdentityResponse(managedIdentityResponse);
93110
clientApplication.tokenCache.saveTokens(tokenRequestExecutor, authenticationResult, clientApplication.authenticationAuthority.host);
94-
return authenticationResult;
111+
AuthenticationResult result = authenticationResult;
112+
result.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER);
113+
result.metadata().cacheRefreshReason(cacheRefreshReason);
114+
return result;
95115
}
96116

97117
private AuthenticationResult createFromManagedIdentityResponse(ManagedIdentityResponse managedIdentityResponse) {
98118
long expiresOn = Long.parseLong(managedIdentityResponse.expiresOn);
99119
long refreshOn = calculateRefreshOn(expiresOn);
100120
AuthenticationResultMetadata metadata = AuthenticationResultMetadata.builder()
121+
.tokenSource(TokenSource.IDENTITY_PROVIDER)
101122
.refreshOn(refreshOn)
102123
.build();
103124

@@ -111,7 +132,7 @@ private AuthenticationResult createFromManagedIdentityResponse(ManagedIdentityRe
111132
.build();
112133
}
113134

114-
private long calculateRefreshOn(long expiresOn){
135+
private long calculateRefreshOn(long expiresOn) {
115136
long timestampSeconds = System.currentTimeMillis() / 1000;
116137
long expiresIn = expiresOn - timestampSeconds;
117138

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenSilentSupplier.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ AuthenticationResult execute() throws Exception {
7373
}
7474
}
7575
}
76+
7677
if (res == null || StringHelper.isBlank(res.accessToken())) {
7778
throw new MsalClientException(AuthenticationErrorMessage.NO_TOKEN_IN_CACHE, AuthenticationErrorCode.CACHE_MISS);
7879
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import java.util.Date;
7+
import org.slf4j.Logger;
8+
9+
class SilentRequestHelper {
10+
11+
private static final int ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC = 5 * 60;
12+
13+
private SilentRequestHelper() {
14+
// Utility class
15+
}
16+
17+
static CacheRefreshReason getCacheRefreshReasonIfApplicable(SilentParameters parameters, AuthenticationResult cachedResult, Logger log) {
18+
// If the request contains claims then the token should be refreshed, to ensure that the returned token has the correct claims
19+
// Note: these are the types of claims found in (for example) a claims challenge, and do not include client capabilities
20+
if (parameters.claims() != null) {
21+
log.debug(String.format("Refreshing access token. Cache refresh reason: %s", CacheRefreshReason.CLAIMS));
22+
return CacheRefreshReason.CLAIMS;
23+
}
24+
25+
long currTimeStampSec = new Date().getTime() / 1000;
26+
27+
// If the access token is expired or within 5 minutes of becoming expired, refresh it
28+
if (!StringHelper.isBlank(cachedResult.accessToken()) && cachedResult.expiresOn() < (currTimeStampSec + ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC)) {
29+
log.debug(String.format("Refreshing access token. Cache refresh reason: %s", CacheRefreshReason.EXPIRED));
30+
return CacheRefreshReason.EXPIRED;
31+
}
32+
33+
// Certain long-lived tokens will have a 'refresh on' time that indicates a refresh should be attempted long before the token would expire
34+
if (!StringHelper.isBlank(cachedResult.accessToken()) &&
35+
cachedResult.refreshOn() != null && cachedResult.refreshOn() > 0 &&
36+
cachedResult.refreshOn() < currTimeStampSec && cachedResult.expiresOn() >= (currTimeStampSec + ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC)){
37+
log.debug(String.format("Refreshing access token. Cache refresh reason: %s", CacheRefreshReason.PROACTIVE_REFRESH));
38+
return CacheRefreshReason.PROACTIVE_REFRESH;
39+
}
40+
41+
// If there is a refresh token but no access token, we should use the refresh token to get the access token
42+
if (StringHelper.isBlank(cachedResult.accessToken()) && !StringHelper.isBlank(cachedResult.refreshToken())) {
43+
log.debug(String.format("Refreshing access token. Cache refresh reason: %s", CacheRefreshReason.NO_CACHED_ACCESS_TOKEN));
44+
return CacheRefreshReason.NO_CACHED_ACCESS_TOKEN;
45+
}
46+
47+
return CacheRefreshReason.NOT_APPLICABLE;
48+
}
49+
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,14 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource
191191
.build()).get();
192192

193193
assertNotNull(result.accessToken());
194-
195-
String accessToken = result.accessToken();
194+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
196195

197196
result = miApp.acquireTokenForManagedIdentity(
198197
ManagedIdentityParameters.builder(resource)
199198
.build()).get();
200199

201200
assertNotNull(result.accessToken());
202-
assertEquals(accessToken, result.accessToken());
201+
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
203202
verify(httpClientMock, times(1)).send(any());
204203
}
205204

@@ -228,6 +227,7 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy
228227
.build()).get();
229228

230229
assertNotNull(result.accessToken());
230+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
231231
verify(httpClientMock, times(1)).send(any());
232232
}
233233

@@ -253,6 +253,7 @@ void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception {
253253
long timestampSeconds = (System.currentTimeMillis() / 1000);
254254

255255
assertNotNull(result.accessToken());
256+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
256257
assertEquals((result.expiresOn() - timestampSeconds)/2, result.refreshOn() - timestampSeconds);
257258

258259
verify(httpClientMock, times(1)).send(any());
@@ -320,14 +321,15 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT
320321
.build()).get();
321322

322323
assertNotNull(result.accessToken());
324+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
323325

324326
result = miApp.acquireTokenForManagedIdentity(
325327
ManagedIdentityParameters.builder(anotherResource)
326328
.build()).get();
327329

328330
assertNotNull(result.accessToken());
331+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
329332
verify(httpClientMock, times(2)).send(any());
330-
// TODO: Assert token source to check the token source is IDP and not Cache.
331333
}
332334

333335
@ParameterizedTest
@@ -565,12 +567,14 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
565567
.build()).get();
566568

567569
assertNotNull(resultMiApp1.accessToken());
570+
assertEquals(TokenSource.IDENTITY_PROVIDER, resultMiApp1.metadata().tokenSource());
568571

569572
IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity(
570573
ManagedIdentityParameters.builder(resource)
571574
.build()).get();
572575

573576
assertNotNull(resultMiApp2.accessToken());
577+
assertEquals(TokenSource.CACHE, resultMiApp2.metadata().tokenSource());
574578

575579
//acquireTokenForManagedIdentity does a cache lookup by default, and all ManagedIdentityApplication's share a cache,
576580
// so calling acquireTokenForManagedIdentity with the same parameters in two different ManagedIdentityApplications

0 commit comments

Comments
 (0)