Skip to content

Commit 8b921c4

Browse files
authored
add pickling support (#34134)
* add pickling support * update * update * update * update * update * typing * black * update * update * updates * black * update * update changelog
1 parent da82787 commit 8b921c4

File tree

8 files changed

+118
-1
lines changed

8 files changed

+118
-1
lines changed

.vscode/cspell.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,12 @@
552552
"Jwcmlud"
553553
]
554554
},
555+
{
556+
"filename": "sdk/identity/azure-identity/tests/*.py",
557+
"words": [
558+
"infile"
559+
]
560+
},
555561
{
556562
"filename": "sdk/identity/test-resources*",
557563
"words": [

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added pickling support. ([#34134](https://github.com/Azure/azure-sdk-for-python/pull/34134))
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def __init__(
5757
self._cache = cache
5858
self._cae_cache = cae_cache
5959
self._cache_options = kwargs.pop("cache_persistence_options", None)
60+
if self._cache or self._cae_cache:
61+
self._custom_cache = True
62+
else:
63+
self._custom_cache = False
6064

6165
def _get_cache(self, **kwargs: Any) -> TokenCache:
6266
cache = self._cae_cache if kwargs.get("enable_cae") else self._cache
@@ -346,6 +350,21 @@ def _post(self, data: Dict, **kwargs: Any) -> HttpRequest:
346350
url = self._get_token_url(**kwargs)
347351
return HttpRequest("POST", url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"})
348352

353+
def __getstate__(self) -> Dict[str, Any]:
354+
state = self.__dict__.copy()
355+
# Remove the non-picklable entries
356+
if not self._custom_cache:
357+
del state["_cache"]
358+
del state["_cae_cache"]
359+
return state
360+
361+
def __setstate__(self, state: Dict[str, Any]) -> None:
362+
self.__dict__.update(state)
363+
# Re-create the unpickable entries
364+
if not self._custom_cache:
365+
self._cache = None
366+
self._cae_cache = None
367+
349368

350369
def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge):
351370
# Represent capabilities as {"access_token": {"xms_cc": {"values": capabilities}}}

sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ def __init__(
2626
identity_config: Optional[Dict] = None,
2727
**kwargs: Any
2828
) -> None:
29-
self._cache = kwargs.pop("_cache", None) or TokenCache()
29+
self._custom_cache = False
30+
self._cache = kwargs.pop("_cache", None)
31+
if self._cache:
32+
self._custom_cache = True
33+
else:
34+
self._cache = TokenCache()
3035
self._content_callback = kwargs.pop("_content_callback", None)
3136
self._identity_config = identity_config or {}
3237
if client_id:
@@ -91,6 +96,19 @@ def request_token(self, *scopes, **kwargs):
9196
def _build_pipeline(self, **kwargs):
9297
pass
9398

99+
def __getstate__(self) -> Dict[str, Any]:
100+
state = self.__dict__.copy()
101+
# Remove the non-picklable entries
102+
if not self._custom_cache:
103+
del state["_cache"]
104+
return state
105+
106+
def __setstate__(self, state: Dict[str, Any]) -> None:
107+
self.__dict__.update(state)
108+
# Re-create the unpickable entries
109+
if not self._custom_cache:
110+
self._cache = TokenCache()
111+
94112

95113
class ManagedIdentityClient(ManagedIdentityClientBase):
96114
def __enter__(self) -> "ManagedIdentityClient":

sdk/identity/azure-identity/azure/identity/_internal/msal_client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,14 @@ def _store_auth_error(self, response: PipelineResponse) -> None:
131131
content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
132132
if content and "error" in content:
133133
self._local.error = (content["error"], response.http_response)
134+
135+
def __getstate__(self) -> Dict[str, Any]: # pylint:disable=client-method-name-no-double-underscore
136+
state = self.__dict__.copy()
137+
# Remove the non-picklable entries
138+
del state["_local"]
139+
return state
140+
141+
def __setstate__(self, state: Dict[str, Any]) -> None: # pylint:disable=client-method-name-no-double-underscore
142+
self.__dict__.update(state)
143+
# Re-create the unpickable entries
144+
self._local = threading.local()

sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def __init__(
5151

5252
self._cache = kwargs.pop("_cache", None)
5353
self._cae_cache = kwargs.pop("_cae_cache", None)
54+
if self._cache or self._cae_cache:
55+
self._custom_cache = True
56+
else:
57+
self._custom_cache = False
5458
self._cache_options = kwargs.pop("cache_persistence_options", None)
5559

5660
super(MsalCredential, self).__init__()
@@ -112,3 +116,22 @@ def _get_app(self, **kwargs: Any) -> msal.ClientApplication:
112116
)
113117

114118
return client_applications_map[tenant_id]
119+
120+
def __getstate__(self) -> Dict[str, Any]:
121+
state = self.__dict__.copy()
122+
# Remove the non-picklable entries
123+
del state["_client_applications"]
124+
del state["_cae_client_applications"]
125+
if not self._custom_cache:
126+
del state["_cache"]
127+
del state["_cae_cache"]
128+
return state
129+
130+
def __setstate__(self, state: Dict[str, Any]) -> None:
131+
self.__dict__.update(state)
132+
# Re-create the unpickable entries
133+
self._client_applications = {}
134+
self._cae_client_applications = {}
135+
if not self._custom_cache:
136+
self._cache = None
137+
self._cae_cache = None
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import pickle
6+
from azure.identity import DefaultAzureCredential
7+
from azure.identity._internal.msal_credentials import MsalCredential
8+
9+
10+
def test_pickle_dac():
11+
cred = DefaultAzureCredential()
12+
with open("data.pkl", "wb") as outfile:
13+
pickle.dump(cred, outfile)
14+
with open("data.pkl", "rb") as infile:
15+
data_loaded = pickle.load(infile)
16+
17+
18+
def test_pickle_msal_credential():
19+
cred = MsalCredential(client_id="CLIENT_ID")
20+
app = cred._get_app()
21+
with open("data.pkl", "wb") as outfile:
22+
pickle.dump(cred, outfile)
23+
with open("data.pkl", "rb") as infile:
24+
data_loaded = pickle.load(infile)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import pickle
6+
from azure.identity.aio import DefaultAzureCredential
7+
8+
9+
def test_pickle_dac():
10+
cred = DefaultAzureCredential()
11+
with open("data_aio.pkl", "wb") as outfile:
12+
pickle.dump(cred, outfile)
13+
with open("data_aio.pkl", "rb") as infile:
14+
data_loaded = pickle.load(infile)

0 commit comments

Comments
 (0)