55 from urllib .parse import urljoin
66import logging
77import sys
8+ import warnings
9+
10+ import requests
811
912from .oauth2cli import Client , JwtSigner
1013from .authority import Authority
1518
1619
1720# The __init__.py will import this. Not the other way around.
18- __version__ = "0.2 .0"
21+ __version__ = "0.3 .0"
1922
2023logger = logging .getLogger (__name__ )
2124
@@ -101,6 +104,14 @@ def __init__(
101104 # Here the self.authority is not the same type as authority in input
102105 self .token_cache = token_cache or TokenCache ()
103106 self .client = self ._build_client (client_credential , self .authority )
107+ self .authority_groups = self ._get_authority_aliases ()
108+
109+ def _get_authority_aliases (self ):
110+ resp = requests .get (
111+ "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize" ,
112+ headers = {'Accept' : 'application/json' })
113+ resp .raise_for_status ()
114+ return [set (group ['aliases' ]) for group in resp .json ()['metadata' ]]
104115
105116 def _build_client (self , client_credential , authority ):
106117 client_assertion = None
@@ -236,19 +247,33 @@ def get_accounts(self, username=None):
236247 Your app can choose to display those information to end user,
237248 and allow user to choose one of his/her accounts to proceed.
238249 """
239- # The following implementation finds accounts only from saved accounts,
240- # but does NOT correlate them with saved RTs. It probably won't matter,
241- # because in MSAL universe, there are always Accounts and RTs together.
242- accounts = self .token_cache .find (
243- self .token_cache .CredentialType .ACCOUNT ,
244- query = {"environment" : self .authority .instance })
250+ accounts = self ._find_msal_accounts (environment = self .authority .instance )
251+ if not accounts : # Now try other aliases of this authority instance
252+ for group in self .authority_groups :
253+ if self .authority .instance in group :
254+ for alias in group :
255+ if alias != self .authority .instance :
256+ accounts = self ._find_msal_accounts (environment = alias )
257+ if accounts :
258+ break
245259 if username :
246260 # Federated account["username"] from AAD could contain mixed case
247261 lowercase_username = username .lower ()
248262 accounts = [a for a in accounts
249263 if a ["username" ].lower () == lowercase_username ]
264+ # Does not further filter by existing RTs here. It probably won't matter.
265+ # Because in most cases Accounts and RTs co-exist.
266+ # Even in the rare case when an RT is revoked and then removed,
267+ # acquire_token_silent() would then yield no result,
268+ # apps would fall back to other acquire methods. This is the standard pattern.
250269 return accounts
251270
271+ def _find_msal_accounts (self , environment ):
272+ return [a for a in self .token_cache .find (
273+ TokenCache .CredentialType .ACCOUNT , query = {"environment" : environment })
274+ if a ["authority_type" ] in (
275+ TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )]
276+
252277 def acquire_token_silent (
253278 self ,
254279 scopes , # type: List[str]
@@ -275,19 +300,44 @@ def acquire_token_silent(
275300 - None when cache lookup does not yield anything.
276301 """
277302 assert isinstance (scopes , list ), "Invalid parameter type"
278- the_authority = Authority (
279- authority ,
280- verify = self .verify , proxies = self .proxies , timeout = self .timeout ,
281- ) if authority else self .authority
282-
303+ if authority :
304+ warnings .warn ("We haven't decided how/if this method will accept authority parameter" )
305+ # the_authority = Authority(
306+ # authority,
307+ # verify=self.verify, proxies=self.proxies, timeout=self.timeout,
308+ # ) if authority else self.authority
309+ result = self ._acquire_token_silent (scopes , account , self .authority , ** kwargs )
310+ if result :
311+ return result
312+ for group in self .authority_groups :
313+ if self .authority .instance in group :
314+ for alias in group :
315+ if alias != self .authority .instance :
316+ the_authority = Authority (
317+ "https://" + alias + "/" + self .authority .tenant ,
318+ validate_authority = False ,
319+ verify = self .verify , proxies = self .proxies ,
320+ timeout = self .timeout ,)
321+ result = self ._acquire_token_silent (
322+ scopes , account , the_authority , ** kwargs )
323+ if result :
324+ return result
325+
326+ def _acquire_token_silent (
327+ self ,
328+ scopes , # type: List[str]
329+ account , # type: Optional[Account]
330+ authority , # This can be different than self.authority
331+ force_refresh = False , # type: Optional[boolean]
332+ ** kwargs ):
283333 if not force_refresh :
284334 matches = self .token_cache .find (
285335 self .token_cache .CredentialType .ACCESS_TOKEN ,
286336 target = scopes ,
287337 query = {
288338 "client_id" : self .client_id ,
289- "environment" : the_authority .instance ,
290- "realm" : the_authority .tenant ,
339+ "environment" : authority .instance ,
340+ "realm" : authority .tenant ,
291341 "home_account_id" : (account or {}).get ("home_account_id" ),
292342 })
293343 now = time .time ()
@@ -301,26 +351,71 @@ def acquire_token_silent(
301351 "token_type" : "Bearer" ,
302352 "expires_in" : int (expires_in ), # OAuth2 specs defines it as int
303353 }
354+ return self ._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
355+ authority , decorate_scope (scopes , self .client_id ), account ,
356+ ** kwargs )
304357
358+ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
359+ self , authority , scopes , account , ** kwargs ):
360+ query = {
361+ "environment" : authority .instance ,
362+ "home_account_id" : (account or {}).get ("home_account_id" ),
363+ # "realm": authority.tenant, # AAD RTs are tenant-independent
364+ }
365+ apps = self .token_cache .find ( # Use find(), rather than token_cache.get(...)
366+ TokenCache .CredentialType .APP_METADATA , query = {
367+ "environment" : authority .instance , "client_id" : self .client_id })
368+ app_metadata = apps [0 ] if apps else {}
369+ if not app_metadata : # Meaning this app is now used for the first time.
370+ # When/if we have a way to directly detect current app's family,
371+ # we'll rewrite this block, to support multiple families.
372+ # For now, we try existing RTs (*). If it works, we are in that family.
373+ # (*) RTs of a different app/family are not supposed to be
374+ # shared with or accessible by us in the first place.
375+ at = self ._acquire_token_silent_by_finding_specific_refresh_token (
376+ authority , scopes ,
377+ dict (query , family_id = "1" ), # A hack, we have only 1 family for now
378+ rt_remover = lambda rt_item : None , # NO-OP b/c RTs are likely not mine
379+ break_condition = lambda response : # Break loop when app not in family
380+ # Based on an AAD-only behavior mentioned in internal doc here
381+ # https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595
382+ "client_mismatch" in response .get ("error_additional_info" , []),
383+ ** kwargs )
384+ if at :
385+ return at
386+ if app_metadata .get ("family_id" ): # Meaning this app belongs to this family
387+ at = self ._acquire_token_silent_by_finding_specific_refresh_token (
388+ authority , scopes , dict (query , family_id = app_metadata ["family_id" ]),
389+ ** kwargs )
390+ if at :
391+ return at
392+ # Either this app is an orphan, so we will naturally use its own RT;
393+ # or all attempts above have failed, so we fall back to non-foci behavior.
394+ return self ._acquire_token_silent_by_finding_specific_refresh_token (
395+ authority , scopes , dict (query , client_id = self .client_id ), ** kwargs )
396+
397+ def _acquire_token_silent_by_finding_specific_refresh_token (
398+ self , authority , scopes , query ,
399+ rt_remover = None , break_condition = lambda response : False , ** kwargs ):
305400 matches = self .token_cache .find (
306401 self .token_cache .CredentialType .REFRESH_TOKEN ,
307402 # target=scopes, # AAD RTs are scope-independent
308- query = {
309- "client_id" : self .client_id ,
310- "environment" : the_authority .instance ,
311- "home_account_id" : (account or {}).get ("home_account_id" ),
312- # "realm": the_authority.tenant, # AAD RTs are tenant-independent
313- })
314- client = self ._build_client (self .client_credential , the_authority )
403+ query = query )
404+ logger .debug ("Found %d RTs matching %s" , len (matches ), query )
405+ client = self ._build_client (self .client_credential , authority )
315406 for entry in matches :
316- logger .debug ("Cache hit an RT" )
407+ logger .debug ("Cache attempts an RT" )
317408 response = client .obtain_token_by_refresh_token (
318409 entry , rt_getter = lambda token_item : token_item ["secret" ],
319- scope = decorate_scope (scopes , self .client_id ))
410+ on_removing_rt = rt_remover or self .token_cache .remove_rt ,
411+ scope = scopes ,
412+ ** kwargs )
320413 if "error" not in response :
321414 return response
322415 logger .debug (
323416 "Refresh failed. {error}: {error_description}" .format (** response ))
417+ if break_condition (response ):
418+ break
324419
325420
326421class PublicClientApplication (ClientApplication ): # browser app or mobile app
0 commit comments