Skip to content

Commit ef54724

Browse files
authored
Refactor SharedTokenCacheCredential (Azure#19914)
1 parent 4192b56 commit ef54724

File tree

3 files changed

+154
-105
lines changed

3 files changed

+154
-105
lines changed

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

Lines changed: 35 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,27 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
import os
6-
import time
7-
8-
from msal.application import PublicClientApplication
9-
10-
from azure.core.credentials import AccessToken
11-
from azure.core.exceptions import ClientAuthenticationError
5+
from typing import TYPE_CHECKING
126

7+
from .silent import SilentAuthenticationCredential
138
from .. import CredentialUnavailableError
149
from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID
15-
from .._internal import AadClient, resolve_tenant, validate_tenant_id
16-
from .._internal.decorators import log_get_token, wrap_exceptions
17-
from .._internal.msal_client import MsalClient
10+
from .._internal import AadClient
11+
from .._internal.decorators import log_get_token
1812
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase
1913

20-
try:
21-
from typing import TYPE_CHECKING
22-
except ImportError:
23-
TYPE_CHECKING = False
24-
2514
if TYPE_CHECKING:
2615
# pylint:disable=unused-import,ungrouped-imports
27-
from typing import Any, Dict, Optional
28-
from .. import AuthenticationRecord
16+
from typing import Any, Optional
17+
from azure.core.credentials import TokenCredential
2918
from .._internal import AadClientBase
3019

3120

32-
class SharedTokenCacheCredential(SharedTokenCacheBase):
21+
class SharedTokenCacheCredential(object):
3322
"""Authenticates using tokens in the local cache shared between Microsoft applications.
3423
35-
:param str username:
36-
Username (typically an email address) of the user to authenticate as. This is used when the local cache
37-
contains tokens for multiple identities.
24+
:param str username: Username (typically an email address) of the user to authenticate as. This is used when the
25+
local cache contains tokens for multiple identities.
3826
3927
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
4028
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
@@ -55,21 +43,13 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
5543
def __init__(self, username=None, **kwargs):
5644
# type: (Optional[str], **Any) -> None
5745

58-
self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
59-
if self._auth_record:
60-
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
61-
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
62-
validate_tenant_id(self._tenant_id)
63-
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
64-
self._cache = kwargs.pop("_cache", None)
65-
self._client_applications = {} # type: Dict[str, PublicClientApplication]
66-
self._msal_client = MsalClient(**kwargs)
67-
self._initialized = False
46+
if "authentication_record" in kwargs:
47+
self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential
6848
else:
69-
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)
49+
self._credential = _SharedTokenCacheCredential(username=username, **kwargs)
7050

7151
@log_get_token("SharedTokenCacheCredential")
72-
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
52+
def get_token(self, *scopes, **kwargs):
7353
# type (*str, **Any) -> AccessToken
7454
"""Get an access token for `scopes` from the shared cache.
7555
@@ -78,14 +58,34 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
7858
This method is called automatically by Azure SDK clients.
7959
8060
:param str scopes: desired scopes for the access token. This method requires at least one scope.
61+
8162
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
82-
claims challenge following an authorization failure
63+
claims challenge following an authorization failure
64+
8365
:rtype: :class:`azure.core.credentials.AccessToken`
66+
8467
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
8568
information
8669
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
87-
attribute gives a reason.
70+
attribute gives a reason.
8871
"""
72+
return self._credential.get_token(*scopes, **kwargs)
73+
74+
@staticmethod
75+
def supported():
76+
# type: () -> bool
77+
"""Whether the shared token cache is supported on the current platform.
78+
79+
:rtype: bool
80+
"""
81+
return SharedTokenCacheBase.supported()
82+
83+
84+
class _SharedTokenCacheCredential(SharedTokenCacheBase):
85+
"""The original SharedTokenCacheCredential, which doesn't use msal.ClientApplication"""
86+
87+
def get_token(self, *scopes, **kwargs):
88+
# type (*str, **Any) -> AccessToken
8989
if not scopes:
9090
raise ValueError("'get_token' requires at least one scope")
9191

@@ -95,9 +95,6 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
9595
if not self._cache:
9696
raise CredentialUnavailableError(message="Shared token cache unavailable")
9797

98-
if self._auth_record:
99-
return self._acquire_token_silent(*scopes, **kwargs)
100-
10198
account = self._get_account(self._username, self._tenant_id)
10299

103100
token = self._get_cached_access_token(scopes, account)
@@ -114,67 +111,3 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
114111
def _get_auth_client(self, **kwargs):
115112
# type: (**Any) -> AadClientBase
116113
return AadClient(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs)
117-
118-
def _initialize(self):
119-
if self._initialized:
120-
return
121-
122-
if not self._auth_record:
123-
super(SharedTokenCacheCredential, self)._initialize()
124-
return
125-
126-
self._load_cache()
127-
self._initialized = True
128-
129-
def _get_client_application(self, **kwargs):
130-
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
131-
if tenant_id not in self._client_applications:
132-
# CP1 = can handle claims challenges (CAE)
133-
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
134-
self._client_applications[tenant_id] = PublicClientApplication(
135-
client_id=self._auth_record.client_id,
136-
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
137-
token_cache=self._cache,
138-
http_client=self._msal_client,
139-
client_capabilities=capabilities
140-
)
141-
return self._client_applications[tenant_id]
142-
143-
@wrap_exceptions
144-
def _acquire_token_silent(self, *scopes, **kwargs):
145-
# type: (*str, **Any) -> AccessToken
146-
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""
147-
148-
# this won't be None when this method is called by get_token but we check anyway to satisfy mypy
149-
if self._auth_record is None:
150-
raise CredentialUnavailableError("Initialization failed")
151-
152-
result = None
153-
154-
client_application = self._get_client_application(**kwargs)
155-
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
156-
if not accounts_for_user:
157-
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")
158-
159-
for account in accounts_for_user:
160-
if account.get("home_account_id") != self._auth_record.home_account_id:
161-
continue
162-
163-
now = int(time.time())
164-
result = client_application.acquire_token_silent_with_error(
165-
list(scopes), account=account, claims_challenge=kwargs.get("claims")
166-
)
167-
if result and "access_token" in result and "expires_in" in result:
168-
return AccessToken(result["access_token"], now + int(result["expires_in"]))
169-
170-
# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
171-
if result:
172-
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
173-
message = "Token acquisition failed"
174-
details = result.get("error_description") or result.get("error")
175-
if details:
176-
message += ": {}".format(details)
177-
raise ClientAuthenticationError(message=message)
178-
179-
# cache doesn't contain a matching refresh (or access) token
180-
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import os
6+
import platform
7+
import time
8+
from typing import TYPE_CHECKING
9+
10+
from msal import PublicClientApplication
11+
12+
from azure.core.credentials import AccessToken
13+
from azure.core.exceptions import ClientAuthenticationError
14+
15+
from .. import CredentialUnavailableError
16+
from .._internal import resolve_tenant, validate_tenant_id
17+
from .._internal.decorators import wrap_exceptions
18+
from .._internal.msal_client import MsalClient
19+
from .._internal.shared_token_cache import NO_TOKEN
20+
from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions
21+
22+
if TYPE_CHECKING:
23+
# pylint:disable=unused-import,ungrouped-imports
24+
from typing import Any, Dict
25+
from .. import AuthenticationRecord
26+
27+
28+
class SilentAuthenticationCredential(object):
29+
"""Internal class for authenticating from the default shared cache given an AuthenticationRecord"""
30+
31+
def __init__(self, authentication_record, **kwargs):
32+
# type: (AuthenticationRecord, **Any) -> None
33+
self._auth_record = authentication_record
34+
35+
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
36+
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
37+
validate_tenant_id(self._tenant_id)
38+
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
39+
self._cache = kwargs.pop("_cache", None)
40+
self._client_applications = {} # type: Dict[str, PublicClientApplication]
41+
self._client = MsalClient(**kwargs)
42+
self._initialized = False
43+
44+
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
45+
# type (*str, **Any) -> AccessToken
46+
if not scopes:
47+
raise ValueError('"get_token" requires at least one scope')
48+
49+
if not self._initialized:
50+
self._initialize()
51+
52+
if not self._cache:
53+
raise CredentialUnavailableError(message="Shared token cache unavailable")
54+
55+
return self._acquire_token_silent(*scopes, **kwargs)
56+
57+
def _initialize(self):
58+
if not self._cache and platform.system() in {"Darwin", "Linux", "Windows"}:
59+
try:
60+
# This credential accepts the user's default cache regardless of whether it's encrypted. It doesn't
61+
# create a new cache. If the default cache exists, the user must have created it earlier. If it's
62+
# unencrypted, the user must have allowed that.
63+
self._cache = _load_persistent_cache(TokenCachePersistenceOptions(allow_unencrypted_storage=True))
64+
except Exception: # pylint:disable=broad-except
65+
pass
66+
67+
self._initialized = True
68+
69+
def _get_client_application(self, **kwargs):
70+
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
71+
if tenant_id not in self._client_applications:
72+
# CP1 = can handle claims challenges (CAE)
73+
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
74+
self._client_applications[tenant_id] = PublicClientApplication(
75+
client_id=self._auth_record.client_id,
76+
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
77+
token_cache=self._cache,
78+
http_client=self._client,
79+
client_capabilities=capabilities
80+
)
81+
return self._client_applications[tenant_id]
82+
83+
@wrap_exceptions
84+
def _acquire_token_silent(self, *scopes, **kwargs):
85+
# type: (*str, **Any) -> AccessToken
86+
"""Silently acquire a token from MSAL."""
87+
88+
result = None
89+
90+
client_application = self._get_client_application(**kwargs)
91+
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
92+
if not accounts_for_user:
93+
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")
94+
95+
for account in accounts_for_user:
96+
if account.get("home_account_id") != self._auth_record.home_account_id:
97+
continue
98+
99+
now = int(time.time())
100+
result = client_application.acquire_token_silent_with_error(
101+
list(scopes), account=account, claims_challenge=kwargs.get("claims")
102+
)
103+
if result and "access_token" in result and "expires_in" in result:
104+
return AccessToken(result["access_token"], now + int(result["expires_in"]))
105+
106+
# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
107+
if result:
108+
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
109+
message = "Token acquisition failed"
110+
details = result.get("error_description") or result.get("error")
111+
if details:
112+
message += ": {}".format(details)
113+
raise ClientAuthenticationError(message=message)
114+
115+
# cache doesn't contain a matching refresh (or access) token
116+
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))

sdk/identity/azure-identity/tests/test_shared_cache_credential.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def send(request, **_):
752752
transport = Mock(send=send)
753753
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
754754

755-
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
755+
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
756756
with pytest.raises(ClientAuthenticationError): # (cache is empty)
757757
credential.get_token("scope")
758758

@@ -761,7 +761,7 @@ def send(request, **_):
761761
assert kwargs["client_capabilities"] == ["CP1"]
762762

763763
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
764-
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
764+
with patch("azure.identity._credentials.silent.PublicClientApplication") as PublicClientApplication:
765765
with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
766766
with pytest.raises(ClientAuthenticationError): # (cache is empty)
767767
credential.get_token("scope")
@@ -786,7 +786,7 @@ def test_claims_challenge():
786786

787787
transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
788788
credential = SharedTokenCacheCredential(transport=transport, authentication_record=record, _cache=TokenCache())
789-
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication", lambda *_, **__: msal_app):
789+
with patch("azure.identity._credentials.silent.PublicClientApplication", lambda *_, **__: msal_app):
790790
credential.get_token("scope", claims=expected_claims)
791791

792792
assert msal_app.acquire_token_silent_with_error.call_count == 1

0 commit comments

Comments
 (0)