Skip to content

Commit 50a5034

Browse files
committed
Add access_token_sha256_to_refresh
1 parent a6d3d0d commit 50a5034

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

msal/application.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import hashlib
23
import json
34
import time
45
import logging
@@ -1520,12 +1521,14 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15201521
correlation_id=None,
15211522
http_exceptions=None,
15221523
auth_scheme=None,
1524+
*,
1525+
access_token_sha256_to_refresh: Optional[str] = None,
15231526
**kwargs):
15241527
# This internal method has two calling patterns:
15251528
# it accepts a non-empty account to find token for a user,
15261529
# and accepts account=None to find a token for the current app.
15271530
access_token_from_cache = None
1528-
if not (force_refresh or claims_challenge or auth_scheme): # Then attempt AT cache
1531+
if access_token_sha256_to_refresh or not (force_refresh or auth_scheme): # Then attempt AT cache
15291532
query={
15301533
"client_id": self.client_id,
15311534
"environment": authority.instance,
@@ -1549,6 +1552,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15491552
if expires_in < 5*60: # Then consider it expired
15501553
refresh_reason = msal.telemetry.AT_EXPIRED
15511554
continue # Removal is not necessary, it will be overwritten
1555+
if access_token_sha256_to_refresh and hashlib.sha256(
1556+
entry["secret"].encode()
1557+
).hexdigest() == access_token_sha256_to_refresh:
1558+
refresh_reason = msal.telemetry.AT_REJECTED
1559+
continue # Might have another useful AT in next loop
15521560
logger.debug("Cache hit an AT")
15531561
access_token_from_cache = { # Mimic a real response
15541562
"access_token": entry["secret"],
@@ -2347,7 +2355,14 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
23472355
except that ``allow_broker`` parameter shall remain ``None``.
23482356
"""
23492357

2350-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2358+
def acquire_token_for_client(
2359+
self,
2360+
scopes: list[str],
2361+
claims_challenge: Optional[str] = None,
2362+
*,
2363+
access_token_sha256_to_refresh: Optional[str] = None,
2364+
**kwargs
2365+
):
23512366
"""Acquires token for the current confidential client, not for an end user.
23522367
23532368
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2360,6 +2375,11 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23602375
in the form of a claims_challenge directive in the www-authenticate header to be
23612376
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
23622377
It is a string of a JSON object which contains lists of claims being requested from these locations.
2378+
:param access_token_sha256_to_refresh:
2379+
If you have an access token that is known to be rejected by the resource,
2380+
you can provide its sha256 here so that MSAL will refresh that specific token.
2381+
2382+
New in version 1.32.0.
23632383
23642384
:return: A dict representing the json response from Microsoft Entra:
23652385
@@ -2371,7 +2391,9 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23712391
"Historically, this method does not support force_refresh behavior. "
23722392
)
23732393
return _clean_up(self._acquire_token_silent_with_error(
2374-
scopes, None, claims_challenge=claims_challenge, **kwargs))
2394+
scopes, account=None, claims_challenge=claims_challenge,
2395+
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
2396+
**kwargs))
23752397

23762398
def _acquire_token_for_client(
23772399
self,

msal/telemetry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AT_EXPIRED = 3
1414
AT_AGING = 4
1515
RESERVED = 5
16+
AT_REJECTED = 6
1617

1718

1819
def _get_new_correlation_id():

tests/test_application.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
3+
import hashlib
34
import json
45
import logging
56
import sys
@@ -62,6 +63,35 @@ def test_bytes_to_bytes(self):
6263
self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))
6364

6465

66+
def fake_token_getter(
67+
*,
68+
access_token: str = "an access token",
69+
status_code: int = 200,
70+
expires_in: int = 3600,
71+
token_type: str = "Bearer",
72+
payload: dict = None,
73+
headers: dict = None,
74+
):
75+
"""A helper to create a fake token getter,
76+
which will be consumed by ClientApplication's acquire methods' post parameter.
77+
78+
Generic mock.patch() is inconvenient because:
79+
1. If you patch it at or above oauth2.py _obtain_token(), token cache is not populated.
80+
2. If you patch it at request.post(), your test cases become fragile because
81+
more http round-trips may be added for future flows,
82+
then your existing test case would break until you mock new round-trips.
83+
"""
84+
return lambda url, *args, **kwargs: MinimalResponse(
85+
status_code=status_code,
86+
text=json.dumps(payload or {
87+
"access_token": access_token,
88+
"expires_in": expires_in,
89+
"token_type": token_type,
90+
}),
91+
headers=headers,
92+
)
93+
94+
6595
class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):
6696

6797
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
@@ -874,3 +904,48 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
874904
parent_window_handle=app.CONSOLE_WINDOW_HANDLE,
875905
)
876906
self.assertEqual(result.get("error"), "broker_error")
907+
908+
909+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
910+
"authorization_endpoint": "https://contoso.com/placeholder",
911+
"token_endpoint": "https://contoso.com/placeholder",
912+
}))
913+
class AccessTokenToRefreshTestCase(unittest.TestCase):
914+
def test_mismatching_hash_should_not_trigger_refresh(self):
915+
scopes = ["scope"]
916+
token1 = "AT one"
917+
token1_hash = hashlib.sha256(token1.encode()).hexdigest()
918+
token2 = "AT two"
919+
app = msal.ConfidentialClientApplication("foo", client_credential="bar")
920+
921+
# Prepopulate cache
922+
app.acquire_token_for_client(scopes, post=fake_token_getter(access_token=token1))
923+
self.assertNotEqual(app.token_cache._cache, {}, "Cache should have been populated")
924+
925+
# Test mismatching hash should not trigger refresh
926+
result = app.acquire_token_for_client(
927+
scopes,
928+
access_token_sha256_to_refresh="mismatching hash",
929+
post=fake_token_getter(access_token=token2))
930+
self.assertEqual(result.get("access_token"), token1, "Should hit old token")
931+
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE)
932+
933+
# Test matching hash should trigger refresh
934+
result = app.acquire_token_for_client(
935+
scopes,
936+
access_token_sha256_to_refresh=token1_hash,
937+
post=fake_token_getter(access_token=token2))
938+
self.assertEqual(result.get("access_token"), token2, "Should obtain new token")
939+
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_IDP)
940+
941+
# A client using old token1 and matching hash, even with claims challenge,
942+
# should not trigger refresh, because we can serve it with token2 in cache.
943+
result = app.acquire_token_for_client(
944+
scopes,
945+
access_token_sha256_to_refresh=token1_hash,
946+
claims_challenge='''{"access_token": {
947+
"access_token": {"nbf": {"essential": true, "value": "1563308371"}
948+
}}''',
949+
post=fake_token_getter(access_token="AT three"))
950+
self.assertEqual(result.get("access_token"), token2, "Token 2 should be returned")
951+
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE)

0 commit comments

Comments
 (0)