55 from urllib .parse import urljoin
66import logging
77import sys
8+ import warnings
9+
10+ import requests
811
912from .oauth2cli import Client , JwtSigner
1013from .authority import Authority
@@ -101,6 +104,14 @@ def __init__(
101104 # Here the self.authority is not the same type as authority in input
102105 self .token_cache = token_cache or TokenCache ()
103106 self .client = self ._build_client (client_credential , self .authority )
107+ self .authority_groups = self ._get_authority_aliases ()
108+
109+ def _get_authority_aliases (self ):
110+ resp = requests .get (
111+ "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize" ,
112+ headers = {'Accept' : 'application/json' })
113+ resp .raise_for_status ()
114+ return [set (group ['aliases' ]) for group in resp .json ()['metadata' ]]
104115
105116 def _build_client (self , client_credential , authority ):
106117 client_assertion = None
@@ -236,11 +247,15 @@ def get_accounts(self, username=None):
236247 Your app can choose to display those information to end user,
237248 and allow user to choose one of his/her accounts to proceed.
238249 """
239- accounts = [a for a in self .token_cache .find ( # Find all useful accounts
240- self .token_cache .CredentialType .ACCOUNT ,
241- query = {"environment" : self .authority .instance })
242- if a ["authority_type" ] in (
243- TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )]
250+ accounts = self ._find_msal_accounts (environment = self .authority .instance )
251+ if not accounts : # Now try other aliases of this authority instance
252+ for group in self .authority_groups :
253+ if self .authority .instance in group :
254+ for alias in group :
255+ if alias != self .authority .instance :
256+ accounts = self ._find_msal_accounts (environment = alias )
257+ if accounts :
258+ break
244259 if username :
245260 # Federated account["username"] from AAD could contain mixed case
246261 lowercase_username = username .lower ()
@@ -253,6 +268,12 @@ def get_accounts(self, username=None):
253268 # apps would fall back to other acquire methods. This is the standard pattern.
254269 return accounts
255270
271+ def _find_msal_accounts (self , environment ):
272+ return [a for a in self .token_cache .find (
273+ TokenCache .CredentialType .ACCOUNT , query = {"environment" : environment })
274+ if a ["authority_type" ] in (
275+ TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )]
276+
256277 def acquire_token_silent (
257278 self ,
258279 scopes , # type: List[str]
@@ -279,19 +300,44 @@ def acquire_token_silent(
279300 - None when cache lookup does not yield anything.
280301 """
281302 assert isinstance (scopes , list ), "Invalid parameter type"
282- the_authority = Authority (
283- authority ,
284- verify = self .verify , proxies = self .proxies , timeout = self .timeout ,
285- ) if authority else self .authority
286-
303+ if authority :
304+ warnings .warn ("We haven't decided how/if this method will accept authority parameter" )
305+ # the_authority = Authority(
306+ # authority,
307+ # verify=self.verify, proxies=self.proxies, timeout=self.timeout,
308+ # ) if authority else self.authority
309+ result = self ._acquire_token_silent (scopes , account , self .authority , ** kwargs )
310+ if result :
311+ return result
312+ for group in self .authority_groups :
313+ if self .authority .instance in group :
314+ for alias in group :
315+ if alias != self .authority .instance :
316+ the_authority = Authority (
317+ "https://" + alias + "/" + self .authority .tenant ,
318+ validate_authority = False ,
319+ verify = self .verify , proxies = self .proxies ,
320+ timeout = self .timeout ,)
321+ result = self ._acquire_token_silent (
322+ scopes , account , the_authority , ** kwargs )
323+ if result :
324+ return result
325+
326+ def _acquire_token_silent (
327+ self ,
328+ scopes , # type: List[str]
329+ account , # type: Optional[Account]
330+ authority , # This can be different than self.authority
331+ force_refresh = False , # type: Optional[boolean]
332+ ** kwargs ):
287333 if not force_refresh :
288334 matches = self .token_cache .find (
289335 self .token_cache .CredentialType .ACCESS_TOKEN ,
290336 target = scopes ,
291337 query = {
292338 "client_id" : self .client_id ,
293- "environment" : the_authority .instance ,
294- "realm" : the_authority .tenant ,
339+ "environment" : authority .instance ,
340+ "realm" : authority .tenant ,
295341 "home_account_id" : (account or {}).get ("home_account_id" ),
296342 })
297343 now = time .time ()
@@ -306,7 +352,7 @@ def acquire_token_silent(
306352 "expires_in" : int (expires_in ), # OAuth2 specs defines it as int
307353 }
308354 return self ._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
309- the_authority , decorate_scope (scopes , self .client_id ), account ,
355+ authority , decorate_scope (scopes , self .client_id ), account ,
310356 ** kwargs )
311357
312358 def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
0 commit comments