1919from .wstrust_response import *
2020from .token_cache import TokenCache
2121import msal .telemetry
22+ from .region import _detect_region
2223
2324
2425# The __init__.py will import this. Not the other way around.
@@ -108,14 +109,21 @@ class ClientApplication(object):
108109 GET_ACCOUNTS_ID = "902"
109110 REMOVE_ACCOUNT_ID = "903"
110111
112+ ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"
113+
111114 def __init__ (
112115 self , client_id ,
113116 client_credential = None , authority = None , validate_authority = True ,
114117 token_cache = None ,
115118 http_client = None ,
116119 verify = True , proxies = None , timeout = None ,
117120 client_claims = None , app_name = None , app_version = None ,
118- client_capabilities = None ):
121+ client_capabilities = None ,
122+ azure_region = None , # Note: We choose to add this param in this base class,
123+ # despite it is currently only needed by ConfidentialClientApplication.
124+ # This way, it holds the same positional param place for PCA,
125+ # when we would eventually want to add this feature to PCA in future.
126+ ):
119127 """Create an instance of application.
120128
121129 :param str client_id: Your app has a client_id after you register it on AAD.
@@ -220,6 +228,53 @@ def __init__(
220228 MSAL will combine them into
221229 `claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
222230 which you will later provide via one of the acquire-token request.
231+
232+ :param str azure_region:
233+ Added since MSAL Python 1.12.0.
234+
235+ As of 2021 May, regional service is only available for
236+ ``acquire_token_for_client()`` sent by any of the following scenarios::
237+
238+ 1. An app powered by a capable MSAL
239+ (MSAL Python 1.12+ will be provisioned)
240+
241+ 2. An app with managed identity, which is formerly known as MSI.
242+ (However MSAL Python does not support managed identity,
243+ so this one does not apply.)
244+
245+ 3. An app authenticated by
246+ `Subject Name/Issuer (SNI) <https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/60>`_.
247+
248+ 4. An app which already onboard to the region's allow-list.
249+
250+ MSAL's default value is None, which means region behavior remains off.
251+ If enabled, the `acquire_token_for_client()`-relevant traffic
252+ would remain inside that region.
253+
254+ App developer can opt in to a regional endpoint,
255+ by provide its region name, such as "westus", "eastus2".
256+ You can find a full list of regions by running
257+ ``az account list-locations -o table``, or referencing to
258+ `this doc <https://docs.microsoft.com/en-us/dotnet/api/microsoft.azure.management.resourcemanager.fluent.core.region?view=azure-dotnet>`_.
259+
260+ An app running inside Azure Functions and Azure VM can use a special keyword
261+ ``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
262+
263+ .. note::
264+
265+ Setting ``azure_region`` to non-``None`` for an app running
266+ outside of Azure Function/VM could hang indefinitely.
267+
268+ You should consider opting in/out region behavior on-demand,
269+ by loading ``azure_region=None`` or ``azure_region="westus"``
270+ or ``azure_region=True`` (which means opt-in and auto-detect)
271+ from your per-deployment configuration, and then do
272+ ``app = ConfidentialClientApplication(..., azure_region=azure_region)``.
273+
274+ Alternatively, you can configure a short timeout,
275+ or provide a custom http_client which has a short timeout.
276+ That way, the latency would be under your control,
277+ but still less performant than opting out of region feature.
223278 """
224279 self .client_id = client_id
225280 self .client_credential = client_credential
@@ -244,12 +299,29 @@ def __init__(
244299
245300 self .app_name = app_name
246301 self .app_version = app_version
247- self .authority = Authority (
302+
303+ # Here the self.authority will not be the same type as authority in input
304+ try :
305+ self .authority = Authority (
248306 authority or "https://login.microsoftonline.com/common/" ,
249307 self .http_client , validate_authority = validate_authority )
250- # Here the self.authority is not the same type as authority in input
308+ except ValueError : # Those are explicit authority validation errors
309+ raise
310+ except Exception : # The rest are typically connection errors
311+ if validate_authority and region :
312+ # Since caller opts in to use region, here we tolerate connection
313+ # errors happened during authority validation at non-region endpoint
314+ self .authority = Authority (
315+ authority or "https://login.microsoftonline.com/common/" ,
316+ self .http_client , validate_authority = False )
317+ else :
318+ raise
319+
251320 self .token_cache = token_cache or TokenCache ()
252- self .client = self ._build_client (client_credential , self .authority )
321+ self ._region_configured = azure_region
322+ self ._region_detected = None
323+ self .client , self ._regional_client = self ._build_client (
324+ client_credential , self .authority )
253325 self .authority_groups = None
254326 self ._telemetry_buffer = {}
255327 self ._telemetry_lock = Lock ()
@@ -260,6 +332,32 @@ def _build_telemetry_context(
260332 self ._telemetry_buffer , self ._telemetry_lock , api_id ,
261333 correlation_id = correlation_id , refresh_reason = refresh_reason )
262334
335+ def _get_regional_authority (self , central_authority ):
336+ is_region_specified = bool (self ._region_configured
337+ and self ._region_configured != self .ATTEMPT_REGION_DISCOVERY )
338+ self ._region_detected = self ._region_detected or _detect_region (
339+ self .http_client if self ._region_configured is not None else None )
340+ if (is_region_specified and self ._region_configured != self ._region_detected ):
341+ logger .warning ('Region configured ({}) != region detected ({})' .format (
342+ repr (self ._region_configured ), repr (self ._region_detected )))
343+ region_to_use = (
344+ self ._region_configured if is_region_specified else self ._region_detected )
345+ if region_to_use :
346+ logger .info ('Region to be used: {}' .format (repr (region_to_use )))
347+ regional_host = ("{}.login.microsoft.com" .format (region_to_use )
348+ if central_authority .instance in (
349+ # The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328
350+ "login.microsoftonline.com" ,
351+ "login.windows.net" ,
352+ "sts.windows.net" ,
353+ )
354+ else "{}.{}" .format (region_to_use , central_authority .instance ))
355+ return Authority (
356+ "https://{}/{}" .format (regional_host , central_authority .tenant ),
357+ self .http_client ,
358+ validate_authority = False ) # The central_authority has already been validated
359+ return None
360+
263361 def _build_client (self , client_credential , authority ):
264362 client_assertion = None
265363 client_assertion_type = None
@@ -298,15 +396,15 @@ def _build_client(self, client_credential, authority):
298396 client_assertion_type = Client .CLIENT_ASSERTION_TYPE_JWT
299397 else :
300398 default_body ['client_secret' ] = client_credential
301- server_configuration = {
399+ central_configuration = {
302400 "authorization_endpoint" : authority .authorization_endpoint ,
303401 "token_endpoint" : authority .token_endpoint ,
304402 "device_authorization_endpoint" :
305403 authority .device_authorization_endpoint or
306404 urljoin (authority .token_endpoint , "devicecode" ),
307405 }
308- return Client (
309- server_configuration ,
406+ central_client = Client (
407+ central_configuration ,
310408 self .client_id ,
311409 http_client = self .http_client ,
312410 default_headers = default_headers ,
@@ -318,6 +416,31 @@ def _build_client(self, client_credential, authority):
318416 on_removing_rt = self .token_cache .remove_rt ,
319417 on_updating_rt = self .token_cache .update_rt )
320418
419+ regional_client = None
420+ if client_credential : # Currently regional endpoint only serves some CCA flows
421+ regional_authority = self ._get_regional_authority (authority )
422+ if regional_authority :
423+ regional_configuration = {
424+ "authorization_endpoint" : regional_authority .authorization_endpoint ,
425+ "token_endpoint" : regional_authority .token_endpoint ,
426+ "device_authorization_endpoint" :
427+ regional_authority .device_authorization_endpoint or
428+ urljoin (regional_authority .token_endpoint , "devicecode" ),
429+ }
430+ regional_client = Client (
431+ regional_configuration ,
432+ self .client_id ,
433+ http_client = self .http_client ,
434+ default_headers = default_headers ,
435+ default_body = default_body ,
436+ client_assertion = client_assertion ,
437+ client_assertion_type = client_assertion_type ,
438+ on_obtaining_tokens = lambda event : self .token_cache .add (dict (
439+ event , environment = authority .instance )),
440+ on_removing_rt = self .token_cache .remove_rt ,
441+ on_updating_rt = self .token_cache .update_rt )
442+ return central_client , regional_client
443+
321444 def initiate_auth_code_flow (
322445 self ,
323446 scopes , # type: list[str]
@@ -953,7 +1076,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
9531076 # target=scopes, # AAD RTs are scope-independent
9541077 query = query )
9551078 logger .debug ("Found %d RTs matching %s" , len (matches ), query )
956- client = self ._build_client (self .client_credential , authority )
1079+ client , _ = self ._build_client (self .client_credential , authority )
9571080
9581081 response = None # A distinguishable value to mean cache is empty
9591082 telemetry_context = self ._build_telemetry_context (
@@ -1304,7 +1427,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
13041427 self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
13051428 telemetry_context = self ._build_telemetry_context (
13061429 self .ACQUIRE_TOKEN_FOR_CLIENT_ID )
1307- response = _clean_up (self .client .obtain_token_for_client (
1430+ client = self ._regional_client or self .client
1431+ response = _clean_up (client .obtain_token_for_client (
13081432 scope = scopes , # This grant flow requires no scope decoration
13091433 headers = telemetry_context .generate_headers (),
13101434 data = dict (
0 commit comments