22# All rights reserved.
33#
44# This code is licensed under the MIT License.
5+ import hashlib
56import json
67import logging
78import os
1011import time
1112from urllib .parse import urlparse # Python 3+
1213from collections import UserDict # Python 3+
13- from typing import Optional , Union # Needed in Python 3.7 & 3.8
14+ from typing import List , Optional , Union # Needed in Python 3.7 & 3.8
1415from .token_cache import TokenCache
1516from .individual_cache import _IndividualCache as IndividualCache
1617from .throttled_http_client import ThrottledHttpClientBase , RetryAfterParser
@@ -162,6 +163,7 @@ def __init__(
162163 http_client ,
163164 token_cache = None ,
164165 http_cache = None ,
166+ client_capabilities : Optional [List [str ]] = None ,
165167 ):
166168 """Create a managed identity client.
167169
@@ -192,6 +194,17 @@ def __init__(
192194 Optional. It has the same characteristics as the
193195 :paramref:`msal.ClientApplication.http_cache`.
194196
197+ :param list[str] client_capabilities: (optional)
198+ Allows configuration of one or more client capabilities, e.g. ["CP1"].
199+
200+ Client capability is meant to inform the Microsoft identity platform
201+ (STS) what this client is capable for,
202+ so STS can decide to turn on certain features.
203+
204+ Implementation details:
205+ Client capability in Managed Identity is relayed as-is
206+ via ``xms_cc`` parameter on the wire.
207+
195208 Recipe 1: Hard code a managed identity for your app::
196209
197210 import msal, requests
@@ -238,6 +251,7 @@ def __init__(
238251 http_cache = http_cache ,
239252 )
240253 self ._token_cache = token_cache or TokenCache ()
254+ self ._client_capabilities = client_capabilities
241255
242256 def _get_instance (self ):
243257 if self .__instance is None :
@@ -266,8 +280,7 @@ def acquire_token_for_client(
266280 and then a *claims challenge* will be returned by the target resource,
267281 as a `claims_challenge` directive in the `www-authenticate` header,
268282 even if the app developer did not opt in for the "CP1" client capability.
269- Upon receiving a `claims_challenge`, MSAL will skip a token cache read,
270- and will attempt to acquire a new token.
283+ Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token.
271284
272285 .. note::
273286
@@ -278,12 +291,14 @@ def acquire_token_for_client(
278291 This is a service-side behavior that cannot be changed by this library.
279292 `Azure VM docs <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>`_
280293 """
294+ access_token_to_refresh = None # This could become a public parameter in the future
281295 access_token_from_cache = None
282296 client_id_in_cache = self ._managed_identity .get (
283297 ManagedIdentity .ID , "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
284298 now = time .time ()
285- if not claims_challenge : # Then attempt token cache search
286- matches = self ._token_cache .find (
299+ if True : # Attempt cache search even if receiving claims_challenge,
300+ # because we want to locate the existing token (if any) and refresh it
301+ matches = self ._token_cache .search (
287302 self ._token_cache .CredentialType .ACCESS_TOKEN ,
288303 target = [resource ],
289304 query = dict (
@@ -297,6 +312,11 @@ def acquire_token_for_client(
297312 expires_in = int (entry ["expires_on" ]) - now
298313 if expires_in < 5 * 60 : # Then consider it expired
299314 continue # Removal is not necessary, it will be overwritten
315+ if claims_challenge and not access_token_to_refresh :
316+ # Since caller did not pinpoint the token causing claims challenge,
317+ # we have to assume it is the first token we found in cache.
318+ access_token_to_refresh = entry ["secret" ]
319+ break
300320 logger .debug ("Cache hit an AT" )
301321 access_token_from_cache = { # Mimic a real response
302322 "access_token" : entry ["secret" ],
@@ -310,7 +330,13 @@ def acquire_token_for_client(
310330 break # With a fallback in hand, we break here to go refresh
311331 return access_token_from_cache # It is still good as new
312332 try :
313- result = _obtain_token (self ._http_client , self ._managed_identity , resource )
333+ result = _obtain_token (
334+ self ._http_client , self ._managed_identity , resource ,
335+ access_token_sha256_to_refresh = hashlib .sha256 (
336+ access_token_to_refresh .encode ("utf-8" )).hexdigest ()
337+ if access_token_to_refresh else None ,
338+ client_capabilities = self ._client_capabilities ,
339+ )
314340 if "access_token" in result :
315341 expires_in = result .get ("expires_in" , 3600 )
316342 if "refresh_in" not in result and expires_in >= 7200 :
@@ -385,8 +411,12 @@ def get_managed_identity_source():
385411 return DEFAULT_TO_VM
386412
387413
388- def _obtain_token (http_client , managed_identity , resource ):
389- # A unified low-level API that talks to different Managed Identity
414+ def _obtain_token (
415+ http_client , managed_identity , resource ,
416+ * ,
417+ access_token_sha256_to_refresh : Optional [str ] = None ,
418+ client_capabilities : Optional [List [str ]] = None ,
419+ ):
390420 if ("IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ
391421 and "IDENTITY_SERVER_THUMBPRINT" in os .environ
392422 ):
@@ -402,6 +432,8 @@ def _obtain_token(http_client, managed_identity, resource):
402432 os .environ ["IDENTITY_HEADER" ],
403433 os .environ ["IDENTITY_SERVER_THUMBPRINT" ],
404434 resource ,
435+ access_token_sha256_to_refresh = access_token_sha256_to_refresh ,
436+ client_capabilities = client_capabilities ,
405437 )
406438 if "IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ :
407439 return _obtain_token_on_app_service (
@@ -553,6 +585,9 @@ def _obtain_token_on_machine_learning(
553585
554586def _obtain_token_on_service_fabric (
555587 http_client , endpoint , identity_header , server_thumbprint , resource ,
588+ * ,
589+ access_token_sha256_to_refresh : str = None ,
590+ client_capabilities : Optional [List [str ]] = None ,
556591):
557592 """Obtains token for
558593 `Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
@@ -563,7 +598,12 @@ def _obtain_token_on_service_fabric(
563598 logger .debug ("Obtaining token via managed identity on Azure Service Fabric" )
564599 resp = http_client .get (
565600 endpoint ,
566- params = {"api-version" : "2019-07-01-preview" , "resource" : resource },
601+ params = {k : v for k , v in {
602+ "api-version" : "2019-07-01-preview" ,
603+ "resource" : resource ,
604+ "token_sha256_to_refresh" : access_token_sha256_to_refresh ,
605+ "xms_cc" : "," .join (client_capabilities ) if client_capabilities else None ,
606+ }.items () if v is not None },
567607 headers = {"Secret" : identity_header },
568608 )
569609 try :
@@ -584,7 +624,7 @@ def _obtain_token_on_service_fabric(
584624 "ArgumentNullOrEmpty" : "invalid_scope" ,
585625 }
586626 return {
587- "error" : error_mapping .get (payload [ " error" ][ " code"] , "invalid_request" ),
627+ "error" : error_mapping .get (error . get ( " code") , "invalid_request" ),
588628 "error_description" : resp .text ,
589629 }
590630 except json .decoder .JSONDecodeError :
0 commit comments