Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
### Bugs Fixed

- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))
- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552))

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ def _get_token_base(

account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)

token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token
if not claims:
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def _initialize_cache(self, is_cae: bool = False) -> TokenCache:
return cast(TokenCache, self._cae_cache if is_cae else self._cache)

def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessTokenInfo]:
# Do not return a cached token if claims are provided.
if kwargs.get("claims"):
return None
tenant = resolve_tenant(
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ async def _get_token_base(

account = self._get_account(self._username, self._tenant_id, is_cae=is_cae)

token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token
if not claims:
token = self._get_cached_access_token(scopes, account, is_cae=is_cae)
if token:
return token

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account, is_cae=is_cae):
Expand Down
33 changes: 33 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,36 @@ def test_claims(method, args):
assert post_mock.call_count == 2
data, _ = post_mock.call_args
assert data[0]["claims"] == cae_merged_claims


def test_get_cached_access_token_with_claims():
"""When claims are provided, get_cached_access_token should return None even if a token is cached"""

client_id = "client-id"
scope = "scope"
cached_token = "cached-access-token"
tenant_id = "tenant"
authority = "https://localhost/" + tenant_id
claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'

# Add a valid token to the cache
cache = TokenCache()
cache.add(
{
"response": build_aad_response(access_token=cached_token, expires_in=3600),
"client_id": client_id,
"scope": [scope],
"token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")),
}
)

client = AadClient(tenant_id=tenant_id, client_id=client_id, authority=authority, cache=cache)

# Without claims, the cached token should be returned
token = client.get_cached_access_token([scope])
assert token is not None
assert token.token == cached_token

# With claims, should return None even though a valid token is cached
token_with_claims = client.get_cached_access_token([scope], claims=claims)
assert token_with_claims is None
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,71 @@ def test_claims_challenge(get_token_method):
assert kwargs["claims_challenge"] == expected_claims


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_claims_skips_cached_access_token(get_token_method):
"""When claims are provided, the credential should skip cached access tokens and request a new one"""

scope = "scope"
expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'
cached_access_token = "cached-access-token"
first_refresh_token = "first-refresh-token"
second_refresh_token = "second-refresh-token"

username = "[email protected]"
uid = "uid"
utid = "utid"

# Set up cache with an access token and refresh token
account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token)
cache = TokenCache()
cache.add(account)

# First request without claims - this will cache an access token
transport = validating_transport(
requests=[Request(required_data={"refresh_token": first_refresh_token})],
responses=[
mock_response(
json_payload=build_aad_response(
uid=uid,
utid=utid,
access_token=cached_access_token,
refresh_token=second_refresh_token,
id_token=build_id_token(
aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username
),
)
)
],
)
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
token = getattr(credential, get_token_method)(scope)
assert token.token == cached_access_token

# Verify the access token is now cached - second request without claims should use it
credential = SharedTokenCacheCredential(
_cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token")))
)
token = getattr(credential, get_token_method)(scope)
assert token.token == cached_access_token

# Now request with claims - should bypass the cached access token and use refresh token
new_token_with_claims = "new-access-token-with-claims"
transport = validating_transport(
requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})],
responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))],
)
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)

kwargs = {"claims": expected_claims}
if get_token_method == "get_token_info":
kwargs = {"options": kwargs}
token = getattr(credential, get_token_method)(scope, **kwargs)

# Should receive the new token with claims, not the cached one
assert token.token == new_token_with_claims
assert token.token != cached_access_token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_multitenant_authentication(get_token_method):
default_tenant = "organizations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from helpers import (
build_aad_response,
build_id_token,
id_token_claims,
mock_response,
get_discovery_response,
Expand All @@ -32,6 +33,7 @@
)
from helpers_async import async_validating_transport, AsyncMockTransport
from test_shared_cache_credential import get_account_event, populated_cache
from azure.identity._constants import DEVELOPER_SIGN_ON_CLIENT_ID


def test_supported():
Expand Down Expand Up @@ -719,6 +721,72 @@ async def test_initialization_with_cache_options(get_token_method):
assert mock_cache_loader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_claims_skips_cached_access_token(get_token_method):
"""When claims are provided, the credential should skip cached access tokens and request a new one"""

scope = "scope"
expected_claims = '{"access_token": {"nbf": {"essential": true, "value": "1234567890"}}}'
cached_access_token = "cached-access-token"
first_refresh_token = "first-refresh-token"
second_refresh_token = "second-refresh-token"

username = "[email protected]"
uid = "uid"
utid = "utid"

# Set up cache with an access token and refresh token
account = get_account_event(username=username, uid=uid, utid=utid, refresh_token=first_refresh_token)
cache = TokenCache()
cache.add(account)

# First request without claims - this will cache an access token
transport = async_validating_transport(
requests=[Request(required_data={"refresh_token": first_refresh_token})],
responses=[
mock_response(
json_payload=build_aad_response(
uid=uid,
utid=utid,
access_token=cached_access_token,
refresh_token=second_refresh_token,
id_token=build_id_token(
aud=DEVELOPER_SIGN_ON_CLIENT_ID, object_id=uid, tenant_id=utid, username=username
),
)
)
],
)
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)
token = await getattr(credential, get_token_method)(scope)
assert token.token == cached_access_token

# Verify the access token is now cached - second request without claims should use it
credential = SharedTokenCacheCredential(
_cache=cache, transport=Mock(send=Mock(side_effect=Exception("should use cached token")))
)
token = await getattr(credential, get_token_method)(scope)
assert token.token == cached_access_token

# Now request with claims - should bypass the cached access token and use refresh token
new_token_with_claims = "new-access-token-with-claims"
transport = async_validating_transport(
requests=[Request(required_data={"refresh_token": second_refresh_token, "claims": expected_claims})],
responses=[mock_response(json_payload=build_aad_response(access_token=new_token_with_claims))],
)
credential = SharedTokenCacheCredential(_cache=cache, transport=transport)

kwargs = {"claims": expected_claims}
if get_token_method == "get_token_info":
kwargs = {"options": kwargs}
token = await getattr(credential, get_token_method)(scope, **kwargs)

# Should receive the new token with claims, not the cached one
assert token.token == new_token_with_claims
assert token.token != cached_access_token


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_multitenant_authentication(get_token_method):
Expand Down
Loading