99import sys
1010import warnings
1111from threading import Lock
12+ import os
1213
1314import requests
1415
@@ -108,14 +109,21 @@ class ClientApplication(object):
108109 GET_ACCOUNTS_ID = "902"
109110 REMOVE_ACCOUNT_ID = "903"
110111
112+ ATTEMPT_REGION_DISCOVERY = "TryAutoDetect"
113+
111114 def __init__ (
112115 self , client_id ,
113116 client_credential = None , authority = None , validate_authority = True ,
114117 token_cache = None ,
115118 http_client = None ,
116119 verify = True , proxies = None , timeout = None ,
117120 client_claims = None , app_name = None , app_version = None ,
118- client_capabilities = None ):
121+ client_capabilities = None ,
122+ region = None , # Note: We choose to add this param in this base class,
123+ # despite it is currently only needed by ConfidentialClientApplication.
124+ # This way, it holds the same positional param place for PCA,
125+ # when we would eventually want to add this feature to PCA in future.
126+ ):
119127 """Create an instance of application.
120128
121129 :param str client_id: Your app has a client_id after you register it on AAD.
@@ -220,6 +228,25 @@ def __init__(
220228 MSAL will combine them into
221229 `claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
222230 which you will later provide via one of the acquire-token request.
231+
232+ :param str region:
233+ Added since MSAL Python 1.12.0.
234+
235+ If enabled, MSAL token requests would remain inside that region.
236+ Currently, regional endpoint only supports using
237+ ``acquire_token_for_client()`` for some scopes.
238+
239+ The default value is None, which means region support remains turned off.
240+
241+ App developer can opt in to regional endpoint,
242+ by provide a region name, such as "westus", "eastus2".
243+
244+ An app running inside Azure VM can use a special keyword
245+ ``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
246+ (Attempting this on a non-VM could hang indefinitely.
247+ Make sure you configure a short timeout,
248+ or provide a custom http_client which has a short timeout.
249+ That way, the latency would be under your control.)
223250 """
224251 self .client_id = client_id
225252 self .client_credential = client_credential
@@ -249,7 +276,10 @@ def __init__(
249276 self .http_client , validate_authority = validate_authority )
250277 # Here the self.authority is not the same type as authority in input
251278 self .token_cache = token_cache or TokenCache ()
252- self .client = self ._build_client (client_credential , self .authority )
279+ self ._region_configured = region
280+ self ._region_detected = None
281+ self .client , self ._regional_client = self ._build_client (
282+ client_credential , self .authority )
253283 self .authority_groups = None
254284 self ._telemetry_buffer = {}
255285 self ._telemetry_lock = Lock ()
@@ -260,6 +290,26 @@ def _build_telemetry_context(
260290 self ._telemetry_buffer , self ._telemetry_lock , api_id ,
261291 correlation_id = correlation_id , refresh_reason = refresh_reason )
262292
293+ def _detect_region (self ):
294+ return os .environ .get ("REGION_NAME" ) # TODO: or Call IMDS
295+
296+ def _get_regional_authority (self , central_authority ):
297+ self ._region_detected = self ._region_detected or self ._detect_region ()
298+ if self ._region_configured and self ._region_detected != self ._region_configured :
299+ logger .warning ('Region configured ({}) != region detected ({})' .format (
300+ repr (self ._region_configured ), repr (self ._region_detected )))
301+ region_to_use = self ._region_configured or self ._region_detected
302+ if region_to_use :
303+ logger .info ('Region to be used: {}' .format (repr (region_to_use )))
304+ regional_host = ("{}.login.microsoft.com" .format (region_to_use )
305+ if central_authority .instance == "login.microsoftonline.com"
306+ else "{}.{}" .format (region_to_use , central_authority .instance ))
307+ return Authority (
308+ "https://{}/{}" .format (regional_host , central_authority .tenant ),
309+ self .http_client ,
310+ validate_authority = False ) # The central_authority has already been validated
311+ return None
312+
263313 def _build_client (self , client_credential , authority ):
264314 client_assertion = None
265315 client_assertion_type = None
@@ -298,15 +348,15 @@ def _build_client(self, client_credential, authority):
298348 client_assertion_type = Client .CLIENT_ASSERTION_TYPE_JWT
299349 else :
300350 default_body ['client_secret' ] = client_credential
301- server_configuration = {
351+ central_configuration = {
302352 "authorization_endpoint" : authority .authorization_endpoint ,
303353 "token_endpoint" : authority .token_endpoint ,
304354 "device_authorization_endpoint" :
305355 authority .device_authorization_endpoint or
306356 urljoin (authority .token_endpoint , "devicecode" ),
307357 }
308- return Client (
309- server_configuration ,
358+ central_client = Client (
359+ central_configuration ,
310360 self .client_id ,
311361 http_client = self .http_client ,
312362 default_headers = default_headers ,
@@ -318,6 +368,30 @@ def _build_client(self, client_credential, authority):
318368 on_removing_rt = self .token_cache .remove_rt ,
319369 on_updating_rt = self .token_cache .update_rt )
320370
371+ regional_client = None
372+ regional_authority = self ._get_regional_authority (authority )
373+ if regional_authority :
374+ regional_configuration = {
375+ "authorization_endpoint" : regional_authority .authorization_endpoint ,
376+ "token_endpoint" : regional_authority .token_endpoint ,
377+ "device_authorization_endpoint" :
378+ regional_authority .device_authorization_endpoint or
379+ urljoin (regional_authority .token_endpoint , "devicecode" ),
380+ }
381+ regional_client = Client (
382+ regional_configuration ,
383+ self .client_id ,
384+ http_client = self .http_client ,
385+ default_headers = default_headers ,
386+ default_body = default_body ,
387+ client_assertion = client_assertion ,
388+ client_assertion_type = client_assertion_type ,
389+ on_obtaining_tokens = lambda event : self .token_cache .add (dict (
390+ event , environment = authority .instance )),
391+ on_removing_rt = self .token_cache .remove_rt ,
392+ on_updating_rt = self .token_cache .update_rt )
393+ return central_client , regional_client
394+
321395 def initiate_auth_code_flow (
322396 self ,
323397 scopes , # type: list[str]
@@ -953,7 +1027,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
9531027 # target=scopes, # AAD RTs are scope-independent
9541028 query = query )
9551029 logger .debug ("Found %d RTs matching %s" , len (matches ), query )
956- client = self ._build_client (self .client_credential , authority )
1030+ client , _ = self ._build_client (self .client_credential , authority )
9571031
9581032 response = None # A distinguishable value to mean cache is empty
9591033 telemetry_context = self ._build_telemetry_context (
@@ -1304,7 +1378,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
13041378 self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
13051379 telemetry_context = self ._build_telemetry_context (
13061380 self .ACQUIRE_TOKEN_FOR_CLIENT_ID )
1307- response = _clean_up (self .client .obtain_token_for_client (
1381+ client = self ._regional_client or self .client
1382+ response = _clean_up (client .obtain_token_for_client (
13081383 scope = scopes , # This grant flow requires no scope decoration
13091384 headers = telemetry_context .generate_headers (),
13101385 data = dict (
0 commit comments