Skip to content

Commit af3fcfe

Browse files
committed
[Identity] Fix issue with token requests with claims
Some credentials do not bypass the the token cache when a `get_token`/`get_token_info` call is made with `claims` provided. This can cause issues when the credential continues to serve an insufficient token. This change updates the credentials where this is an issue. Signed-off-by: Paul Van Eck <[email protected]>
1 parent a95f7a8 commit af3fcfe

File tree

7 files changed

+181
-6
lines changed

7 files changed

+181
-6
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
### Bugs Fixed
1414

1515
- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))
16+
- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))
1617

1718
### Other Changes
1819

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,10 @@ def _get_token_base(
185185

186186
account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)
187187

188-
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
189-
if token:
190-
return token
188+
if not claims:
189+
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
190+
if token:
191+
return token
191192

192193
# try each refresh token, returning the first access token acquired
193194
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def _initialize_cache(self, is_cae: bool = False) -> TokenCache:
9292
return cast(TokenCache, self._cae_cache if is_cae else self._cache)
9393

9494
def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessTokenInfo]:
95+
# Do not return a cached token if claims are provided.
96+
if kwargs.get("claims"):
97+
return None
9598
tenant = resolve_tenant(
9699
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
97100
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ async def _get_token_base(
138138

139139
account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)
140140

141-
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
142-
if token:
143-
return token
141+
if not claims:
142+
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
143+
if token:
144+
return token
144145

145146
# try each refresh token, returning the first access token acquired
146147
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):

sdk/identity/azure-identity/tests/test_aad_client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,3 +451,36 @@ def test_claims(method, args):
451451
assert post_mock.call_count == 2
452452
data, _ = post_mock.call_args
453453
assert data[0]["claims"] == cae_merged_claims
454+
455+
456+
def test_get_cached_access_token_with_claims():
457+
"""When claims are provided, get_cached_access_token should return None even if a token is cached"""
458+
459+
client_id = "client-id"
460+
scope = "scope"
461+
cached_token = "cached-access-token"
462+
tenant_id = "tenant"
463+
authority = "https://localhost/" + tenant_id
464+
claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'
465+
466+
# Add a valid token to the cache
467+
cache = TokenCache()
468+
cache.add(
469+
{
470+
"response": build_aad_response(access_token=cached_token, expires_in=3600),
471+
"client_id": client_id,
472+
"scope": [scope],
473+
"token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")),
474+
}
475+
)
476+
477+
client = AadClient(tenant_id=tenant_id, client_id=client_id, authority=authority, cache=cache)
478+
479+
# Without claims, the cached token should be returned
480+
token = client.get_cached_access_token([scope])
481+
assert token is not None
482+
assert token.token == cached_token
483+
484+
# With claims, should return None even though a valid token is cached
485+
token_with_claims = client.get_cached_access_token([scope], claims=claims)
486+
assert token_with_claims is None

sdk/identity/azure-identity/tests/test_shared_cache_credential.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import base64
6+
import json
7+
58
from azure.core.exceptions import ClientAuthenticationError
69
from azure.core.pipeline.policies import SansIOHTTPPolicy
710
from azure.identity import (
@@ -974,6 +977,71 @@ def test_claims_challenge(get_token_method):
974977
assert kwargs["claims_challenge"] == expected_claims
975978

976979

980+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
981+
def test_claims_skips_cached_access_token(get_token_method):
982+
"""When claims are provided, the credential should skip cached access tokens and request a new one"""
983+
984+
scope = "scope"
985+
expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'
986+
cached_access_token = "cached-access-token"
987+
first_refresh_token = "first-refresh-token"
988+
second_refresh_token = "second-refresh-token"
989+
990+
username = "[email protected]"
991+
uid = "uid"
992+
utid = "utid"
993+
994+
# Set up cache with an access token and refresh token
995+
account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token)
996+
cache = TokenCache()
997+
cache.add(account)
998+
999+
# First request without claims - this will cache an access token
1000+
transport = validating_transport(
1001+
requests=[Request(required_data={"refresh_token": first_refresh_token})],
1002+
responses=[
1003+
mock_response(
1004+
json_payload=build_aad_response(
1005+
uid=uid,
1006+
utid=utid,
1007+
access_token=cached_access_token,
1008+
refresh_token=second_refresh_token,
1009+
id_token=build_id_token(
1010+
aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username
1011+
),
1012+
)
1013+
)
1014+
],
1015+
)
1016+
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
1017+
token = getattr(credential, get_token_method)(scope)
1018+
assert token.token == cached_access_token
1019+
1020+
# Verify the access token is now cached - second request without claims should use it
1021+
credential = SharedTokenCacheCredential(
1022+
_cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token")))
1023+
)
1024+
token = getattr(credential, get_token_method)(scope)
1025+
assert token.token == cached_access_token
1026+
1027+
# Now request with claims - should bypass the cached access token and use refresh token
1028+
new_token_with_claims = "new-access-token-with-claims"
1029+
transport = validating_transport(
1030+
requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})],
1031+
responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))],
1032+
)
1033+
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
1034+
1035+
kwargs = {"claims": expected_claims}
1036+
if get_token_method == "get_token_info":
1037+
kwargs = {"options": kwargs}
1038+
token = getattr(credential, get_token_method)(scope, **kwargs)
1039+
1040+
# Should receive the new token with claims, not the cached one
1041+
assert token.token == new_token_with_claims
1042+
assert token.token != cached_access_token
1043+
1044+
9771045
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
9781046
def test_multitenant_authentication(get_token_method):
9791047
default_tenant = "organizations"

sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from helpers import (
2626
build_aad_response,
27+
build_id_token,
2728
id_token_claims,
2829
mock_response,
2930
get_discovery_response,
@@ -32,6 +33,7 @@
3233
)
3334
from helpers_async import async_validating_transport, AsyncMockTransport
3435
from test_shared_cache_credential import get_account_event, populated_cache
36+
from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID
3537

3638

3739
def test_supported():
@@ -719,6 +721,72 @@ async def test_initialization_with_cache_options(get_token_method):
719721
assert mock_cache_loader.call_count == 1
720722

721723

724+
@pytest.mark.asyncio
725+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
726+
async def test_claims_skips_cached_access_token(get_token_method):
727+
"""When claims are provided, the credential should skip cached access tokens and request a new one"""
728+
729+
scope = "scope"
730+
expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'
731+
cached_access_token = "cached-access-token"
732+
first_refresh_token = "first-refresh-token"
733+
second_refresh_token = "second-refresh-token"
734+
735+
username = "[email protected]"
736+
uid = "uid"
737+
utid = "utid"
738+
739+
# Set up cache with an access token and refresh token
740+
account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token)
741+
cache = TokenCache()
742+
cache.add(account)
743+
744+
# First request without claims - this will cache an access token
745+
transport = async_validating_transport(
746+
requests=[Request(required_data={"refresh_token": first_refresh_token})],
747+
responses=[
748+
mock_response(
749+
json_payload=build_aad_response(
750+
uid=uid,
751+
utid=utid,
752+
access_token=cached_access_token,
753+
refresh_token=second_refresh_token,
754+
id_token=build_id_token(
755+
aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username
756+
),
757+
)
758+
)
759+
],
760+
)
761+
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
762+
token = await getattr(credential, get_token_method)(scope)
763+
assert token.token == cached_access_token
764+
765+
# Verify the access token is now cached - second request without claims should use it
766+
credential = SharedTokenCacheCredential(
767+
_cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token")))
768+
)
769+
token = await getattr(credential, get_token_method)(scope)
770+
assert token.token == cached_access_token
771+
772+
# Now request with claims - should bypass the cached access token and use refresh token
773+
new_token_with_claims = "new-access-token-with-claims"
774+
transport = async_validating_transport(
775+
requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})],
776+
responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))],
777+
)
778+
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
779+
780+
kwargs = {"claims": expected_claims}
781+
if get_token_method == "get_token_info":
782+
kwargs = {"options": kwargs}
783+
token = await getattr(credential, get_token_method)(scope, **kwargs)
784+
785+
# Should receive the new token with claims, not the cached one
786+
assert token.token == new_token_with_claims
787+
assert token.token != cached_access_token
788+
789+
722790
@pytest.mark.asyncio
723791
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
724792
async def test_multitenant_authentication(get_token_method):

0 commit comments

Comments
 (0)