diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 061a653e837e..157821d16612 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -13,6 +13,7 @@ ### Bugs Fixed - Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347)) +- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) ### Other Changes diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 0ef0e5b6d3da..4ee7aa187ee7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -185,9 +185,10 @@ def _get_token_base( account = self._get_account(self._username, self._tenant_id, is_cae=is_cae) - token = self._get_cached_access_token(scopes, account, is_cae=is_cae) - if token: - return token + if not claims: + token = self._get_cached_access_token(scopes, account, is_cae=is_cae) + if token: + return token # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 876ba56a5c7e..b40574d8f2e6 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -92,6 +92,9 @@ def _initialize_cache(self, is_cae: bool = False) -> TokenCache: return cast(TokenCache, self._cae_cache if is_cae else self._cache) def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessTokenInfo]: + # Do not return a cached token if claims are provided. + if kwargs.get("claims"): + return None tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 0d14c508ec9b..2c27af171517 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -138,9 +138,10 @@ async def _get_token_base( account = self._get_account(self._username, self._tenant_id, is_cae=is_cae) - token = self._get_cached_access_token(scopes, account, is_cae=is_cae) - if token: - return token + if not claims: + token = self._get_cached_access_token(scopes, account, is_cae=is_cae) + if token: + return token # try each refresh token, returning the first access token acquired for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae): diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 6facdafb4629..56148e39ce72 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -451,3 +451,36 @@ def test_claims(method, args): assert post_mock.call_count == 2 data, _ = post_mock.call_args assert data[0]["claims"] == cae_merged_claims + + +def test_get_cached_access_token_with_claims(): + """When claims are provided, get_cached_access_token should return None even if a token is cached""" + + client_id = "client-id" + scope = "scope" + cached_token = "cached-access-token" + tenant_id = "tenant" + authority = "https://localhost/" + tenant_id + claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}' + + # Add a valid token to the cache + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=cached_token, expires_in=3600), + "client_id": client_id, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")), + } + ) + + client = AadClient(tenant_id=tenant_id, client_id=client_id, authority=authority, cache=cache) + + # Without claims, the cached token should be returned + token = client.get_cached_access_token([scope]) + assert token is not None + assert token.token == cached_token + + # With claims, should return None even though a valid token is cached + token_with_claims = client.get_cached_access_token([scope], claims=claims) + assert token_with_claims is None diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 5ab9c08db482..f1c64da124a9 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -974,6 +974,71 @@ def test_claims_challenge(get_token_method): assert kwargs["claims_challenge"] == expected_claims +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +def test_claims_skips_cached_access_token(get_token_method): + """When claims are provided, the credential should skip cached access tokens and request a new one""" + + scope = "scope" + expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}' + cached_access_token = "cached-access-token" + first_refresh_token = "first-refresh-token" + second_refresh_token = "second-refresh-token" + + username = "user@example.com" + uid = "uid" + utid = "utid" + + # Set up cache with an access token and refresh token + account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token) + cache = TokenCache() + cache.add(account) + + # First request without claims - this will cache an access token + transport = validating_transport( + requests=[Request(required_data={"refresh_token": first_refresh_token})], + responses=[ + mock_response( + json_payload=build_aad_response( + uid=uid, + utid=utid, + access_token=cached_access_token, + refresh_token=second_refresh_token, + id_token=build_id_token( + aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username + ), + ) + ) + ], + ) + credential = SharedTokenCacheCredential(_cache=cache, transport=transport) + token = getattr(credential, get_token_method)(scope) + assert token.token == cached_access_token + + # Verify the access token is now cached - second request without claims should use it + credential = SharedTokenCacheCredential( + _cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token"))) + ) + token = getattr(credential, get_token_method)(scope) + assert token.token == cached_access_token + + # Now request with claims - should bypass the cached access token and use refresh token + new_token_with_claims = "new-access-token-with-claims" + transport = validating_transport( + requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})], + responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))], + ) + credential = SharedTokenCacheCredential(_cache=cache, transport=transport) + + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = getattr(credential, get_token_method)(scope, **kwargs) + + # Should receive the new token with claims, not the cached one + assert token.token == new_token_with_claims + assert token.token != cached_access_token + + @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) def test_multitenant_authentication(get_token_method): default_tenant = "organizations" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index d8ce17cf370e..da51952c13bb 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -24,6 +24,7 @@ from helpers import ( build_aad_response, + build_id_token, id_token_claims, mock_response, get_discovery_response, @@ -32,6 +33,7 @@ ) from helpers_async import async_validating_transport, AsyncMockTransport from test_shared_cache_credential import get_account_event, populated_cache +from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID def test_supported(): @@ -719,6 +721,72 @@ async def test_initialization_with_cache_options(get_token_method): assert mock_cache_loader.call_count == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_claims_skips_cached_access_token(get_token_method): + """When claims are provided, the credential should skip cached access tokens and request a new one""" + + scope = "scope" + expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}' + cached_access_token = "cached-access-token" + first_refresh_token = "first-refresh-token" + second_refresh_token = "second-refresh-token" + + username = "user@example.com" + uid = "uid" + utid = "utid" + + # Set up cache with an access token and refresh token + account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token) + cache = TokenCache() + cache.add(account) + + # First request without claims - this will cache an access token + transport = async_validating_transport( + requests=[Request(required_data={"refresh_token": first_refresh_token})], + responses=[ + mock_response( + json_payload=build_aad_response( + uid=uid, + utid=utid, + access_token=cached_access_token, + refresh_token=second_refresh_token, + id_token=build_id_token( + aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username + ), + ) + ) + ], + ) + credential = SharedTokenCacheCredential(_cache=cache, transport=transport) + token = await getattr(credential, get_token_method)(scope) + assert token.token == cached_access_token + + # Verify the access token is now cached - second request without claims should use it + credential = SharedTokenCacheCredential( + _cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token"))) + ) + token = await getattr(credential, get_token_method)(scope) + assert token.token == cached_access_token + + # Now request with claims - should bypass the cached access token and use refresh token + new_token_with_claims = "new-access-token-with-claims" + transport = async_validating_transport( + requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})], + responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))], + ) + credential = SharedTokenCacheCredential(_cache=cache, transport=transport) + + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": kwargs} + token = await getattr(credential, get_token_method)(scope, **kwargs) + + # Should receive the new token with claims, not the cached one + assert token.token == new_token_with_claims + assert token.token != cached_access_token + + @pytest.mark.asyncio @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) async def test_multitenant_authentication(get_token_method):