Skip to content

Commit a87a36f

Browse files
authored
Add pickling support for SharedTokenCacheCredential (#36404)
1 parent 0c2a270 commit a87a36f

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
### Bugs Fixed
1010

11+
- Fixed the issue that `SharedTokenCacheCredential` was not picklable.
12+
1113
### Other Changes
1214

1315
## 1.17.1 (2024-06-21)

sdk/identity/azure-identity/azure/identity/_credentials/silent.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def __init__(
3838
validate_tenant_id(self._tenant_id)
3939
self._cache = kwargs.pop("_cache", None)
4040
self._cae_cache = kwargs.pop("_cae_cache", None)
41+
if self._cache or self._cae_cache:
42+
self._custom_cache = True
43+
else:
44+
self._custom_cache = False
4145

4246
self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)
4347

@@ -162,3 +166,18 @@ def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken:
162166

163167
# cache doesn't contain a matching refresh (or access) token
164168
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
169+
170+
def __getstate__(self) -> Dict[str, Any]:
171+
state = self.__dict__.copy()
172+
# Remove the non-picklable entries
173+
if not self._custom_cache:
174+
del state["_cache"]
175+
del state["_cae_cache"]
176+
return state
177+
178+
def __setstate__(self, state: Dict[str, Any]) -> None:
179+
self.__dict__.update(state)
180+
# Re-create the unpickable entries
181+
if not self._custom_cache:
182+
self._cache = None
183+
self._cae_cache = None

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import abc
66
import platform
77
import time
8-
from typing import Any, Iterable, List, Mapping, Optional, cast
8+
from typing import Any, Iterable, List, Mapping, Optional, cast, Dict
99
from urllib.parse import urlparse
1010
import msal
1111

@@ -96,6 +96,10 @@ def __init__(
9696
self._tenant_id = tenant_id
9797
self._cache = kwargs.pop("_cache", None)
9898
self._cae_cache = kwargs.pop("_cae_cache", None)
99+
if self._cache or self._cae_cache:
100+
self._custom_cache = True
101+
else:
102+
self._custom_cache = False
99103
self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)
100104
self._client_kwargs = kwargs
101105
self._client_kwargs["tenant_id"] = "organizations"
@@ -267,3 +271,18 @@ def supported() -> bool:
267271
:rtype: bool
268272
"""
269273
return platform.system() in {"Darwin", "Linux", "Windows"}
274+
275+
def __getstate__(self) -> Dict[str, Any]:
276+
state = self.__dict__.copy()
277+
# Remove the non-picklable entries
278+
if not self._custom_cache:
279+
del state["_cache"]
280+
del state["_cae_cache"]
281+
return state
282+
283+
def __setstate__(self, state: Dict[str, Any]) -> None:
284+
self.__dict__.update(state)
285+
# Re-create the unpickable entries
286+
if not self._custom_cache:
287+
self._cache = None
288+
self._cae_cache = None

sdk/identity/azure-identity/tests/test_pickling.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import pickle
6-
from azure.identity import DefaultAzureCredential
6+
from azure.identity import DefaultAzureCredential, SharedTokenCacheCredential
77
from azure.identity._internal.msal_credentials import MsalCredential
88

99

@@ -15,6 +15,15 @@ def test_pickle_dac():
1515
data_loaded = pickle.load(infile)
1616

1717

18+
def test_pickle_shared_token_cache():
19+
cred = SharedTokenCacheCredential()
20+
cred._credential._initialize_cache()
21+
with open("data.pkl", "wb") as outfile:
22+
pickle.dump(cred, outfile)
23+
with open("data.pkl", "rb") as infile:
24+
data_loaded = pickle.load(infile)
25+
26+
1827
def test_pickle_msal_credential():
1928
cred = MsalCredential(client_id="CLIENT_ID")
2029
app = cred._get_app()

0 commit comments

Comments
 (0)