22# All rights reserved.
33#
44# This code is licensed under the MIT License.
5+ import hashlib
56import json
67import logging
78import os
@@ -266,8 +267,7 @@ def acquire_token_for_client(
266267 and then a *claims challenge* will be returned by the target resource,
267268 as a `claims_challenge` directive in the `www-authenticate` header,
268269 even if the app developer did not opt in for the "CP1" client capability.
269- Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
270- and will attempt to acquire a new token.
270+ Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
271271
272272 .. note::
273273
@@ -278,11 +278,13 @@ def acquire_token_for_client(
278278 This is a service-side behavior that cannot be changed by this library.
279279 `Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
280280 """
281+ access_token_to_refresh = None # This could become a public parameter in the future
281282 access_token_from_cache = None
282283 client_id_in_cache = self ._managed_identity .get (
283284 ManagedIdentity .ID , "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
284285 now = time .time ()
285- if not claims_challenge : # Then attempt token cache search
286+ if True : # Attempt cache search even if receiving claims_challenge,
287+ # because we want to locate the existing token (if any) and refresh it
286288 matches = self ._token_cache .find (
287289 self ._token_cache .CredentialType .ACCESS_TOKEN ,
288290 target = [resource ],
@@ -297,6 +299,11 @@ def acquire_token_for_client(
297299 expires_in = int (entry ["expires_on" ]) - now
298300 if expires_in < 5 * 60 : # Then consider it expired
299301 continue # Removal is not necessary, it will be overwritten
302+ if claims_challenge and not access_token_to_refresh :
303+ # Since caller did not pinpoint the token causing claims challenge,
304+ # we have to assume it is the first token we found in cache.
305+ access_token_to_refresh = entry ["secret" ]
306+ break
300307 logger .debug ("Cache hit an AT" )
301308 access_token_from_cache = { # Mimic a real response
302309 "access_token" : entry ["secret" ],
@@ -310,7 +317,13 @@ def acquire_token_for_client(
310317 break # With a fallback in hand, we break here to go refresh
311318 return access_token_from_cache # It is still good as new
312319 try :
313- result = _obtain_token (self ._http_client , self ._managed_identity , resource )
320+ result = _obtain_token (
321+ self ._http_client , self ._managed_identity , resource ,
322+ claims_challenge = claims_challenge ,
323+ access_token_sha256_to_refresh = hashlib .sha256 (
324+ access_token_to_refresh .encode ("utf-8" )).hexdigest ()
325+ if access_token_to_refresh else None ,
326+ )
314327 if "access_token" in result :
315328 expires_in = result .get ("expires_in" , 3600 )
316329 if "refresh_in" not in result and expires_in >= 7200 :
@@ -385,8 +398,14 @@ def get_managed_identity_source():
385398 return DEFAULT_TO_VM
386399
387400
388- def _obtain_token (http_client , managed_identity , resource ):
389- # A unified low-level API that talks to different Managed Identity
401+ def _obtain_token (
402+ http_client , managed_identity , resource ,
403+ * ,
404+ claims_challenge : Optional [str ] = None ,
405+ access_token_sha256_to_refresh : Optional [str ] = None ,
406+ ):
407+ if claims_challenge and len (claims_challenge ) > 1024 :
408+ claims_challenge = None # MSIv1 does not support long url
390409 if ("IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ
391410 and "IDENTITY_SERVER_THUMBPRINT" in os .environ
392411 ):
@@ -402,6 +421,8 @@ def _obtain_token(http_client, managed_identity, resource):
402421 os .environ ["IDENTITY_HEADER" ],
403422 os .environ ["IDENTITY_SERVER_THUMBPRINT" ],
404423 resource ,
424+ claims_challenge = claims_challenge ,
425+ access_token_sha256_to_refresh = access_token_sha256_to_refresh ,
405426 )
406427 if "IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ :
407428 return _obtain_token_on_app_service (
@@ -551,6 +572,9 @@ def _obtain_token_on_machine_learning(
551572
552573def _obtain_token_on_service_fabric (
553574 http_client , endpoint , identity_header , server_thumbprint , resource ,
575+ * ,
576+ claims_challenge : str = None ,
577+ access_token_sha256_to_refresh : str = None ,
554578):
555579 """Obtains token for
556580 `Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
@@ -561,7 +585,12 @@ def _obtain_token_on_service_fabric(
561585 logger .debug ("Obtaining token via managed identity on Azure Service Fabric" )
562586 resp = http_client .get (
563587 endpoint ,
564- params = {"api-version" : "2019-07-01-preview" , "resource" : resource },
588+ params = {k : v for k , v in {
589+ "api-version" : "2019-07-01-preview" ,
590+ "resource" : resource ,
591+ "claims" : claims_challenge ,
592+ "token_sha256_to_refresh" : access_token_sha256_to_refresh ,
593+ }.items () if v is not None },
565594 headers = {"Secret" : identity_header },
566595 )
567596 try :
0 commit comments