Skip to content

Commit 46293fe

Browse files
authored
[Identity] Convert remaining cache.find usage (#36734)
Since find is deprecated, search is the API we should be using. Signed-off-by: Paul Van Eck <[email protected]>
1 parent e5b455b commit 46293fe

11 files changed

+34
-31
lines changed

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,11 @@ def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optio
8787
)
8888

8989
cache = self._get_cache(**kwargs)
90-
tokens = cache.find(
90+
for token in cache.search(
9191
TokenCache.CredentialType.ACCESS_TOKEN,
9292
target=list(scopes),
9393
query={"client_id": self._client_id, "realm": tenant},
94-
)
95-
for token in tokens:
94+
):
9695
expires_on = int(token["expires_on"])
9796
if expires_on > int(time.time()):
9897
return AccessToken(token["secret"], expires_on)
@@ -101,7 +100,7 @@ def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optio
101100
def get_cached_refresh_tokens(self, scopes: Iterable[str], **kwargs) -> List[Dict]:
102101
# Assumes all cached refresh tokens belong to the same user
103102
cache = self._get_cache(**kwargs)
104-
return cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))
103+
return list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes)))
105104

106105
@abc.abstractmethod
107106
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
@@ -140,17 +139,21 @@ def _process_response(self, response: PipelineResponse, request_time: int, **kwa
140139
if response.http_request.body.get("grant_type") == "refresh_token":
141140
if content.get("error") == "invalid_grant":
142141
# the request's refresh token is invalid -> evict it from the cache
143-
cache_entries = cache.find(
144-
TokenCache.CredentialType.REFRESH_TOKEN,
145-
query={"secret": response.http_request.body["refresh_token"]},
142+
cache_entries = list(
143+
cache.search(
144+
TokenCache.CredentialType.REFRESH_TOKEN,
145+
query={"secret": response.http_request.body["refresh_token"]},
146+
)
146147
)
147148
for invalid_token in cache_entries:
148149
cache.remove_rt(invalid_token)
149150
if "refresh_token" in content:
150151
# Microsoft Entra ID returned a new refresh token -> update the cache entry
151-
cache_entries = cache.find(
152-
TokenCache.CredentialType.REFRESH_TOKEN,
153-
query={"secret": response.http_request.body["refresh_token"]},
152+
cache_entries = list(
153+
cache.search(
154+
TokenCache.CredentialType.REFRESH_TOKEN,
155+
query={"secret": response.http_request.body["refresh_token"]},
156+
)
154157
)
155158
# If the old token is in multiple cache entries, the cache is in a state we don't
156159
# expect or know how to reason about, so we update nothing.

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ def _get_cache_items_for_authority(
157157
:rtype: list[CacheItem]
158158
"""
159159

160-
cache = self._cae_cache if is_cae else self._cache
160+
cache = cast(msal.TokenCache, self._cae_cache if is_cae else self._cache)
161161
items = []
162-
for item in cache.find(credential_type):
162+
for item in cache.search(credential_type):
163163
environment = item.get("environment")
164164
if environment in self._environment_aliases:
165165
items.append(item)
@@ -232,9 +232,9 @@ def _get_cached_access_token(
232232
if "home_account_id" not in account:
233233
return None
234234

235-
cache = self._cae_cache if is_cae else self._cache
235+
cache = cast(msal.TokenCache, self._cae_cache if is_cae else self._cache)
236236
try:
237-
cache_entries = cache.find(
237+
cache_entries = cache.search(
238238
msal.TokenCache.CredentialType.ACCESS_TOKEN,
239239
target=list(scopes),
240240
query={"home_account_id": account["home_account_id"]},
@@ -253,9 +253,9 @@ def _get_refresh_tokens(self, account, is_cae: bool = False) -> List[str]:
253253
if "home_account_id" not in account:
254254
return []
255255

256-
cache = self._cae_cache if is_cae else self._cache
256+
cache = cast(msal.TokenCache, self._cae_cache if is_cae else self._cache)
257257
try:
258-
cache_entries = cache.find(
258+
cache_entries = cache.search(
259259
msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account["home_account_id"]}
260260
)
261261
return [token["secret"] for token in cache_entries if "secret" in token]

sdk/identity/azure-identity/tests/test_aad_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ def test_evicts_invalid_refresh_token():
202202
cache = TokenCache()
203203
cache.add({"response": build_aad_response(uid="id1", utid="tid1", access_token="*", refresh_token=invalid_token)})
204204
cache.add({"response": build_aad_response(uid="id2", utid="tid2", access_token="*", refresh_token="...")})
205-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 2
206-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 1
205+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 2
206+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token}))) == 1
207207

208208
def send(request, **_):
209209
assert request.data["refresh_token"] == invalid_token
@@ -216,8 +216,8 @@ def send(request, **_):
216216
client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=invalid_token)
217217

218218
assert transport.send.call_count == 1
219-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
220-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 0
219+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 1
220+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token}))) == 0
221221

222222

223223
def test_retries_token_requests():

sdk/identity/azure-identity/tests/test_aad_client_async.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ async def test_evicts_invalid_refresh_token():
196196
cache = TokenCache()
197197
cache.add({"response": build_aad_response(uid="id1", utid="tid1", access_token="*", refresh_token=invalid_token)})
198198
cache.add({"response": build_aad_response(uid="id2", utid="tid2", access_token="*", refresh_token="...")})
199-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 2
200-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 1
199+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 2
200+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token}))) == 1
201201

202202
async def send(request, **_):
203203
assert request.data["refresh_token"] == invalid_token
@@ -210,8 +210,8 @@ async def send(request, **_):
210210
await client.obtain_token_by_refresh_token(scopes=("scope",), refresh_token=invalid_token)
211211

212212
assert transport.send.call_count == 1
213-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN)) == 1
214-
assert len(cache.find(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token})) == 0
213+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN))) == 1
214+
assert len(list(cache.search(TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": invalid_token}))) == 0
215215

216216

217217
async def test_retries_token_requests():

sdk/identity/azure-identity/tests/test_auth_code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_auth_code_credential():
136136
assert transport.send.call_count == 1
137137

138138
# no auth code, no cached token -> credential should redeem refresh token
139-
cached_access_token = cache.find(cache.CredentialType.ACCESS_TOKEN)[0]
139+
cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0]
140140
cache.remove_at(cached_access_token)
141141
token = credential.get_token(expected_scope)
142142
assert token.token == expected_access_token

sdk/identity/azure-identity/tests/test_auth_code_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ async def test_auth_code_credential():
160160
assert transport.send.call_count == 1
161161

162162
# no auth code, no cached token -> credential should redeem refresh token
163-
cached_access_token = cache.find(cache.CredentialType.ACCESS_TOKEN)[0]
163+
cached_access_token = list(cache.search(cache.CredentialType.ACCESS_TOKEN))[0]
164164
cache.remove_at(cached_access_token)
165165
token = await credential.get_token(expected_scope)
166166
assert token.token == expected_access_token

sdk/identity/azure-identity/tests/test_certificate_credential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_persistent_cache_multiple_clients(cert_path, cert_password):
414414
assert token_b.token == access_token_b
415415
assert transport_b.send.call_count == 2
416416

417-
assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
417+
assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2
418418

419419

420420
def test_certificate_arguments():

sdk/identity/azure-identity/tests/test_certificate_credential_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ async def test_persistent_cache_multiple_clients(cert_path, cert_password):
334334
assert transport_b.send.call_count == 1
335335
assert mock_cache_loader.call_count == 2
336336

337-
assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
337+
assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2
338338

339339

340340
def test_certificate_arguments():

sdk/identity/azure-identity/tests/test_client_secret_credential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def test_cache_multiple_clients():
260260
assert token_b.token == access_token_b
261261
assert transport_b.send.call_count == 2
262262

263-
assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
263+
assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2
264264

265265

266266
def test_multitenant_authentication():

sdk/identity/azure-identity/tests/test_client_secret_credential_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ async def test_cache_multiple_clients():
304304
assert transport_b.send.call_count == 1
305305
assert mock_cache_loader.call_count == 2
306306

307-
assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2
307+
assert len(list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 2
308308

309309

310310
@pytest.mark.asyncio

0 commit comments

Comments
 (0)