1616 from collections import UserDict
1717except :
1818 UserDict = dict # The real UserDict is an old-style class which fails super()
19+ from .token_cache import TokenCache
1920
2021
2122logger = logging .getLogger (__name__ )
@@ -104,6 +105,7 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
104105
105106
106107def _obtain_token (http_client , managed_identity , resource ):
108+ # A unified low-level API that talks to different Managed Identity
107109 if ("IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ
108110 and "IDENTITY_SERVER_THUMBPRINT" in os .environ
109111 ):
@@ -303,6 +305,12 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
303305
304306
305307class ManagedIdentityClient (object ):
308+ """A low level API that encapulate multiple managed identity backends:
309+ VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
310+ and Azure Arc.
311+
312+ It also provides token cache support.
313+ """
306314 _instance , _tenant = socket .getfqdn (), "managed_identity" # Placeholders
307315
308316 def __init__ (self , http_client , managed_identity , token_cache = None ):
@@ -319,16 +327,17 @@ def __init__(self, http_client, managed_identity, token_cache=None):
319327
320328 :param token_cache:
321329 Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
330+ It will use an in-memory token cache by default.
322331
323- Example : Hard code a managed identity for your app::
332+ Recipe 1 : Hard code a managed identity for your app::
324333
325334 import msal, requests
326335 client = msal.ManagedIdentityClient(
327336 requests.Session(),
328337 msal.UserAssignedManagedIdentity(client_id="foo"),
329338 )
330339
331- Recipe: Write once, run everywhere.
340+ Recipe 2 : Write once, run everywhere.
332341 If you use different managed identity on different deployment,
333342 you may use an environment variable (such as AZURE_MANAGED_IDENTITY)
334343 to store a json blob like
@@ -346,7 +355,7 @@ def __init__(self, http_client, managed_identity, token_cache=None):
346355 """
347356 self ._http_client = http_client
348357 self ._managed_identity = managed_identity
349- self ._token_cache = token_cache
358+ self ._token_cache = token_cache or TokenCache ()
350359
351360 def acquire_token (self , resource = None ):
352361 """Acquire token for the managed identity.
@@ -361,7 +370,8 @@ def acquire_token(self, resource=None):
361370 access_token_from_cache = None
362371 client_id_in_cache = self ._managed_identity .get (
363372 ManagedIdentity .ID , "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
364- if self ._token_cache :
373+ if True : # Does not offer an "if not force_refresh" option, because
374+ # there would be built-in token cache in the service side anyway
365375 matches = self ._token_cache .find (
366376 self ._token_cache .CredentialType .ACCESS_TOKEN ,
367377 target = [resource ],
@@ -386,17 +396,26 @@ def acquire_token(self, resource=None):
386396 if "refresh_on" in entry and int (entry ["refresh_on" ]) < now : # aging
387397 break # With a fallback in hand, we break here to go refresh
388398 return access_token_from_cache # It is still good as new
389- result = _obtain_token (self ._http_client , self ._managed_identity , resource )
390- if self ._token_cache and "access_token" in result :
391- self ._token_cache .add (dict (
392- client_id = client_id_in_cache ,
393- scope = [resource ],
394- token_endpoint = "https://{}/{}" .format (self ._instance , self ._tenant ),
395- response = result ,
396- params = {},
397- data = {},
398- #grant_type="placeholder",
399- ))
400- return result
401- return access_token_from_cache or result
399+ try :
400+ result = _obtain_token (self ._http_client , self ._managed_identity , resource )
401+ if "access_token" in result :
402+ expires_in = result .get ("expires_in" , 3600 )
403+ if "refresh_in" not in result and expires_in >= 7200 :
404+ result ["refresh_in" ] = int (expires_in / 2 )
405+ self ._token_cache .add (dict (
406+ client_id = client_id_in_cache ,
407+ scope = [resource ],
408+ token_endpoint = "https://{}/{}" .format (self ._instance , self ._tenant ),
409+ response = result ,
410+ params = {},
411+ data = {},
412+ #grant_type="placeholder",
413+ ))
414+ if (result and "error" not in result ) or (not access_token_from_cache ):
415+ return result
416+ except : # The exact HTTP exception is transportation-layer dependent
417+ # Typically network error. Potential AAD outage?
418+ if not access_token_from_cache : # It means there is no fall back option
419+ raise # We choose to bubble up the exception
420+ return access_token_from_cache
402421
0 commit comments