Skip to content

Commit da01336

Browse files
committed
Add access_token_sha256_to_refresh
1 parent 0340f5e commit da01336

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/test_application.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
3+
import hashlib
34
import json
45
import logging
56
import sys
@@ -56,6 +57,35 @@ def test_bytes_to_bytes(self):
5657
self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))
5758

5859

60+
def fake_token_getter(
61+
*,
62+
access_token: str = "an access token",
63+
status_code: int = 200,
64+
expires_in: int = 3600,
65+
token_type: str = "Bearer",
66+
payload: dict = None,
67+
headers: dict = None,
68+
):
69+
"""A helper to create a fake token getter,
70+
which will be consumed by ClientApplication's acquire methods' post parameter.
71+
72+
Generic mock.patch() is inconvenient because:
73+
1. If you patch it at or above oauth2.py _obtain_token(), token cache is not populated.
74+
2. If you patch it at request.post(), your test cases become fragile because
75+
more http round-trips may be added for future flows,
76+
then your existing test case would break until you mock new round-trips.
77+
"""
78+
return lambda url, *args, **kwargs: MinimalResponse(
79+
status_code=status_code,
80+
text=json.dumps(payload or {
81+
"access_token": access_token,
82+
"expires_in": expires_in,
83+
"token_type": token_type,
84+
}),
85+
headers=headers,
86+
)
87+
88+
5989
class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):
6090

6191
def setUp(self):
@@ -856,3 +886,30 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
856886
)
857887
self.assertEqual(result.get("error"), "broker_error")
858888

889+
890+
@patch("msal.authority.tenant_discovery", new=Mock(return_value={
891+
"authorization_endpoint": "https://contoso.com/placeholder",
892+
"token_endpoint": "https://contoso.com/placeholder",
893+
}))
894+
class AccessTokenToRefreshTestCase(unittest.TestCase):
895+
def test_mismatching_hash_should_not_trigger_refresh(self):
896+
scopes = ["scope"]
897+
old_token = "old AT"
898+
new_token = "new AT"
899+
app = msal.ConfidentialClientApplication("foo", client_credential="bar")
900+
app.acquire_token_for_client(scopes, post=fake_token_getter(access_token=old_token))
901+
self.assertNotEqual(app.token_cache._cache, {}, "Cache should have been populated")
902+
903+
result = app.acquire_token_for_client(
904+
scopes,
905+
access_token_sha256_to_refresh="mismatching hash",
906+
post=fake_token_getter(access_token=new_token))
907+
self.assertEqual(result.get("access_token"), old_token, "Should hit old token")
908+
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE)
909+
910+
result = app.acquire_token_for_client(
911+
scopes,
912+
access_token_sha256_to_refresh=hashlib.sha256(old_token.encode()).hexdigest(),
913+
post=fake_token_getter(access_token=new_token))
914+
self.assertEqual(result.get("access_token"), new_token, "Should obtain new token")
915+
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_IDP)

0 commit comments

Comments
 (0)