Skip to content

Commit b3e716b

Browse files
authored
[Identity] Fix issue with cae_cache not being used (#42145)
Signed-off-by: Paul Van Eck <[email protected]>
1 parent ff82059 commit b3e716b

File tree

6 files changed

+29
-12
lines changed

6 files changed

+29
-12
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
### Bugs Fixed
1010

11+
- Fixed an issue where CAE (Continuous Access Evaluation) caches were not properly used by `AuthorizationCodeCredential` and the asynchronous `OnBehalfOfCredential`. ([#42145](https://github.com/Azure/azure-sdk-for-python/pull/42145))
12+
1113
### Other Changes
1214

1315
## 1.24.0b1 (2025-07-17)

sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo:
127127
return token
128128

129129
token = None
130-
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
130+
for refresh_token in self._client.get_cached_refresh_tokens(scopes, **kwargs):
131131
if "secret" in refresh_token:
132132
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
133133
if token:

sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
133133
return token
134134

135135
token = cast(AccessTokenInfo, None)
136-
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
136+
for refresh_token in self._client.get_cached_refresh_tokens(scopes, **kwargs):
137137
if "secret" in refresh_token:
138138
token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
139139
if token:

sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
119119
# Note we assume the cache has tokens for one user only. That's okay because each instance of this class is
120120
# locked to a single user (assertion). This assumption will become unsafe if this class allows applications
121121
# to change an instance's assertion.
122-
refresh_tokens = self._client.get_cached_refresh_tokens(scopes)
122+
refresh_tokens = self._client.get_cached_refresh_tokens(scopes, **kwargs)
123123
if len(refresh_tokens) == 1: # there should be only one
124124
try:
125125
refresh_token = refresh_tokens[0]["secret"]

sdk/identity/azure-identity/tests/test_auth_code.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def test_tenant_id(get_token_method):
8080

8181

8282
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
83-
def test_auth_code_credential(get_token_method):
83+
@pytest.mark.parametrize("enable_cae", [True, False])
84+
def test_auth_code_credential(get_token_method, enable_cae):
8485
client_id = "client id"
8586
secret = "fake-client-secret"
8687
tenant_id = "tenant"
@@ -118,6 +119,7 @@ def test_auth_code_credential(get_token_method):
118119
responses=[mock_response(json_payload=auth_response)] * 2,
119120
)
120121
cache = msal.TokenCache()
122+
cae_cache = msal.TokenCache()
121123

122124
credential = AuthorizationCodeCredential(
123125
client_id=client_id,
@@ -127,22 +129,29 @@ def test_auth_code_credential(get_token_method):
127129
redirect_uri=redirect_uri,
128130
transport=transport,
129131
cache=cache,
132+
cae_cache=cae_cache,
130133
)
131134

132135
# first call should redeem the auth code
133-
token = getattr(credential, get_token_method)(expected_scope)
136+
kwargs = {"enable_cae": enable_cae}
137+
if get_token_method == "get_token_info":
138+
kwargs = {"options": kwargs}
139+
token = getattr(credential, get_token_method)(expected_scope, **kwargs)
134140
assert token.token == expected_access_token
135141
assert transport.send.call_count == 1
136142

137143
# no auth code -> credential should return cached token
138-
token = getattr(credential, get_token_method)(expected_scope)
144+
token = getattr(credential, get_token_method)(expected_scope, **kwargs)
139145
assert token.token == expected_access_token
140146
assert transport.send.call_count == 1
141147

142148
# no auth code, no cached token -> credential should redeem refresh token
143-
cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0]
144-
cache.remove_at(cached_access_token)
145-
token = getattr(credential, get_token_method)(expected_scope)
149+
cache_being_used = cae_cache if enable_cae else cache
150+
cached_tokens = list(cache_being_used.search(cache_being_used.CredentialType.ACCESS_TOKEN))
151+
assert cached_tokens
152+
cached_access_token = cached_tokens[0]
153+
cache_being_used.remove_at(cached_access_token)
154+
token = getattr(credential, get_token_method)(expected_scope, **kwargs)
146155
assert token.token == expected_access_token
147156
assert transport.send.call_count == 2
148157

sdk/identity/azure-identity/tests/test_obo_async.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def test_invalid_cert():
249249

250250
@pytest.mark.asyncio
251251
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
252-
async def test_refresh_token(get_token_method):
252+
@pytest.mark.parametrize("enable_cae", [True, False])
253+
async def test_refresh_token(get_token_method, enable_cae):
253254
first_token = "***"
254255
second_token = first_token * 2
255256
refresh_token = "refresh-token"
@@ -274,10 +275,15 @@ async def send(request, **kwargs):
274275
credential = OnBehalfOfCredential(
275276
"tenant-id", "client-id", client_secret="secret", user_assertion="assertion", transport=Mock(send=send)
276277
)
277-
token = await getattr(credential, get_token_method)("scope")
278+
279+
kwargs = {"enable_cae": enable_cae}
280+
if get_token_method == "get_token_info":
281+
kwargs = {"options": kwargs}
282+
283+
token = await getattr(credential, get_token_method)("scope", **kwargs)
278284
assert token.token == first_token
279285

280-
token = await getattr(credential, get_token_method)("scope")
286+
token = await getattr(credential, get_token_method)("scope", **kwargs)
281287
assert token.token == second_token
282288

283289
assert requests == 2

0 commit comments

Comments
 (0)