Skip to content

Commit 0922465

Browse files
committed
Add access_token_sha256_to_refresh
1 parent 137dee4 commit 0922465

File tree

3 files changed

+118
-4
lines changed

3 files changed

+118
-4
lines changed

msal/application.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import functools
2+
import hashlib
23
import json
34
import time
45
import logging
56
import sys
67
import warnings
78
from threading import Lock
8-
from typing import Optional # Needed in Python 3.7 & 3.8
9+
from typing import List, Optional # Needed in Python 3.7 & 3.8
910
from urllib.parse import urlparse
1011
import os
1112

@@ -1520,12 +1521,17 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15201521
correlation_id=None,
15211522
http_exceptions=None,
15221523
auth_scheme=None,
1524+
*,
1525+
access_token_sha256_to_refresh: Optional[str] = None,
15231526
**kwargs):
15241527
# This internal method has two calling patterns:
15251528
# it accepts a non-empty account to find token for a user,
15261529
# and accepts account=None to find a token for the current app.
15271530
access_token_from_cache = None
1528-
if not (force_refresh or claims_challenge or auth_scheme): # Then attempt AT cache
1531+
if access_token_sha256_to_refresh and force_refresh:
1532+
raise ValueError(
1533+
"access_token_sha256_to_refresh and force_refresh are mutually exclusive")
1534+
if access_token_sha256_to_refresh or not (force_refresh or auth_scheme): # Then attempt AT cache
15291535
query={
15301536
"client_id": self.client_id,
15311537
"environment": authority.instance,
@@ -1549,6 +1555,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15491555
if expires_in < 5*60: # Then consider it expired
15501556
refresh_reason = msal.telemetry.AT_EXPIRED
15511557
continue # Removal is not necessary, it will be overwritten
1558+
if access_token_sha256_to_refresh and hashlib.sha256(
1559+
entry["secret"].encode()
1560+
).hexdigest() == access_token_sha256_to_refresh:
1561+
refresh_reason = msal.telemetry.AT_REJECTED
1562+
continue # Might have another useful AT in next loop
15521563
logger.debug("Cache hit an AT")
15531564
access_token_from_cache = { # Mimic a real response
15541565
"access_token": entry["secret"],
@@ -2347,7 +2358,14 @@ class ConfidentialClientApplication(ClientApplication): # server-side web app
23472358
except that ``allow_broker`` parameter shall remain ``None``.
23482359
"""
23492360

2350-
def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
2361+
def acquire_token_for_client(
2362+
self,
2363+
scopes: List[str],
2364+
claims_challenge: Optional[str] = None,
2365+
*,
2366+
access_token_sha256_to_refresh: Optional[str] = None,
2367+
**kwargs
2368+
):
23512369
"""Acquires token for the current confidential client, not for an end user.
23522370
23532371
Since MSAL Python 1.23, it will automatically look for token from cache,
@@ -2360,6 +2378,11 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23602378
in the form of a claims_challenge directive in the www-authenticate header to be
23612379
returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token.
23622380
It is a string of a JSON object which contains lists of claims being requested from these locations.
2381+
:param access_token_sha256_to_refresh:
2382+
If you have an access token that is known to be rejected by the resource,
2383+
you can provide its sha256 here so that MSAL will refresh that specific token.
2384+
2385+
New in version 1.32.0.
23632386
23642387
:return: A dict representing the json response from Microsoft Entra:
23652388
@@ -2371,7 +2394,9 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
23712394
"Historically, this method does not support force_refresh behavior. "
23722395
)
23732396
return _clean_up(self._acquire_token_silent_with_error(
2374-
scopes, None, claims_challenge=claims_challenge, **kwargs))
2397+
scopes, account=None, claims_challenge=claims_challenge,
2398+
access_token_sha256_to_refresh=access_token_sha256_to_refresh,
2399+
**kwargs))
23752400

23762401
def _acquire_token_for_client(
23772402
self,

msal/telemetry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
AT_EXPIRED = 3
1414
AT_AGING = 4
1515
RESERVED = 5
16+
AT_REJECTED = 6
1617

1718

1819
def _get_new_correlation_id():

tests/test_application.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
3+
import hashlib
34
import json
45
import logging
56
import sys
@@ -62,6 +63,35 @@ def test_bytes_to_bytes(self):
6263
self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))
6364

6465

66+
def fake_token_getter(
67+
*,
68+
access_token: str = "an access token",
69+
status_code: int = 200,
70+
expires_in: int = 3600,
71+
token_type: str = "Bearer",
72+
payload: dict = None,
73+
headers: dict = None,
74+
):
75+
"""A helper to create a fake token getter,
76+
which will be consumed by ClientApplication's acquire methods' post parameter.
77+
78+
Generic mock.patch() is inconvenient because:
79+
1. If you patch it at or above oauth2.py _obtain_token(), token cache is not populated.
80+
2. If you patch it at request.post(), your test cases become fragile because
81+
more http round-trips may be added for future flows,
82+
then your existing test case would break until you mock new round-trips.
83+
"""
84+
return lambda url, *args, **kwargs: MinimalResponse(
85+
status_code=status_code,
86+
text=json.dumps(payload or {
87+
"access_token": access_token,
88+
"expires_in": expires_in,
89+
"token_type": token_type,
90+
}),
91+
headers=headers,
92+
)
93+
94+
6595
class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):
6696

6797
@patch(_OIDC_DISCOVERY, new=_OIDC_DISCOVERY_MOCK)
@@ -874,3 +904,61 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
874904
parent_window_handle=app.CONSOLE_WINDOW_HANDLE,
875905
)
876906
self.assertEqual(result.get("error"), "broker_error")
907+
908+
909+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
910+
"authorization_endpoint": "https://contoso.com/placeholder",
911+
"token_endpoint": "https://contoso.com/placeholder",
912+
}))
913+
class AccessTokenToRefreshTestCase(unittest.TestCase):
914+
scopes = ["scope"]
915+
token1 = "AT one"
916+
token1_hash = hashlib.sha256(token1.encode()).hexdigest()
917+
token2 = "AT two"
918+
919+
def setUp(self):
920+
self.app = msal.ConfidentialClientApplication("id", client_credential="*")
921+
# Prepopulate cache
922+
self.app.acquire_token_for_client(
923+
self.scopes, post=fake_token_getter(access_token=self.token1))
924+
self.assertNotEqual(
925+
self.app.token_cache._cache, {}, "Cache should have been populated")
926+
927+
def test_mismatching_hash_should_not_trigger_refresh(self):
928+
result = self.app.acquire_token_for_client(
929+
self.scopes,
930+
access_token_sha256_to_refresh="mismatching hash",
931+
post=fake_token_getter(access_token=self.token2))
932+
self.assertEqual(result.get("access_token"), self.token1, "Should hit old token")
933+
self.assertEqual(result.get("token_source"), self.app._TOKEN_SOURCE_CACHE)
934+
935+
def test_matching_hash_should_trigger_refresh(self):
936+
result = self.app.acquire_token_for_client(
937+
self.scopes,
938+
access_token_sha256_to_refresh=self.token1_hash,
939+
post=fake_token_getter(access_token=self.token2))
940+
self.assertEqual(result.get("access_token"), self.token2, "Should obtain new token")
941+
self.assertEqual(result.get("token_source"), self.app._TOKEN_SOURCE_IDP)
942+
943+
# A client using old token1 and valid old hash, even with claims challenge,
944+
# should not trigger refresh, because we can serve it with token2 in cache.
945+
result = self.app.acquire_token_for_client(
946+
self.scopes,
947+
access_token_sha256_to_refresh=self.token1_hash,
948+
claims_challenge='''{"access_token": {
949+
"access_token": {"nbf": {"essential": true, "value": "1563308371"}
950+
}}''',
951+
post=fake_token_getter(access_token="AT three"))
952+
self.assertEqual(result.get("access_token"), self.token2, "Token 2 should be returned")
953+
self.assertEqual(result.get("token_source"), self.app._TOKEN_SOURCE_CACHE)
954+
955+
def test_force_refresh_alone_should_trigger_refresh(self):
956+
# Note: MSAL Python's acquire_token_for_client() never support force_refresh,
957+
# but let's ensure _acquire_token_silent_with_error(..., force_refresh=True)
958+
# bypasses cache, so that other account-based flows can still use it.
959+
result = self.app._acquire_token_silent_with_error(
960+
self.scopes, account=None, force_refresh=True,
961+
post=fake_token_getter(access_token=self.token2))
962+
self.assertEqual(result.get("access_token"), self.token2, "Should hit new token")
963+
self.assertEqual(result.get("token_source"), self.app._TOKEN_SOURCE_IDP)
964+

0 commit comments

Comments
 (0)