Skip to content

Commit 716174b

Browse files
committed
[Identity] Expand usage of regional authorities
Currently, only MSAL confidential client based credentials allow passing in a region through the `AZURE_REGIONAL_AUTHORITY_NAME`. For the other non-MSAL based confidential flows, we should have parity with the MSAL ones. Signed-off-by: Paul Van Eck <[email protected]>
1 parent 4aa385a commit 716174b

File tree

6 files changed

+311
-7
lines changed

6 files changed

+311
-7
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
### Bugs Fixed
1414

15+
- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))
16+
1517
### Other Changes
1618

1719
## 1.26.0b1 (2025-11-07)

sdk/identity/azure-identity/azure/identity/_internal/aad_client.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import time
6+
import logging
7+
import os
68
from typing import Iterable, Union, Optional, Any
79

810
from azure.core.credentials import AccessTokenInfo
@@ -11,9 +13,19 @@
1113
from .aad_client_base import AadClientBase
1214
from .aadclient_certificate import AadClientCertificate
1315
from .pipeline import build_pipeline
16+
from .._enums import RegionalAuthority
1417

1518

16-
class AadClient(AadClientBase):
19+
_LOGGER = logging.getLogger(__name__)
20+
21+
22+
class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-keyword
23+
24+
# pylint:disable=missing-client-constructor-parameter-credential
25+
def __init__(self, *args: Any, **kwargs: Any) -> None:
26+
super().__init__(*args, **kwargs)
27+
self._initialize_regional_authority()
28+
1729
def __enter__(self) -> "AadClient":
1830
self._pipeline.__enter__()
1931
return self
@@ -62,6 +74,37 @@ def obtain_token_on_behalf_of(
6274
# no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL
6375
raise NotImplementedError()
6476

77+
def _initialize_regional_authority(self) -> None:
78+
# This is based on MSAL's regional authority logic.
79+
if self._regional_authority is not False:
80+
return
81+
82+
regional_authority = self._get_regional_authority_from_env()
83+
if not regional_authority:
84+
self._regional_authority = None
85+
return
86+
87+
if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION:
88+
regional_authority = self._discover_region()
89+
if not regional_authority:
90+
_LOGGER.warning("Failed to auto-discover region. Using the non-regional authority.")
91+
self._regional_authority = None
92+
return
93+
94+
self._regional_authority = self._build_regional_authority_url(regional_authority)
95+
96+
def _discover_region(self) -> Optional[str]:
97+
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
98+
if region:
99+
return region
100+
try:
101+
request = self._get_region_discovery_request()
102+
response = self._pipeline.run(request)
103+
return self._process_region_discovery_response(response)
104+
except Exception as ex: # pylint: disable=broad-except
105+
_LOGGER.info("Failed to discover Azure region from IMDS: %s", ex)
106+
return None
107+
65108
def _build_pipeline(self, **kwargs: Any) -> Pipeline:
66109
return build_pipeline(**kwargs)
67110

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
import abc
66
import base64
77
import json
8+
import logging
9+
import os
810
import time
911
from uuid import uuid4
12+
from urllib.parse import urlparse
1013
from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast
1114

1215
from msal import TokenCache
@@ -19,6 +22,7 @@
1922
from .utils import get_default_authority, normalize_authority, resolve_tenant
2023
from .aadclient_certificate import AadClientCertificate
2124
from .._persistent_cache import _load_persistent_cache
25+
from .._constants import EnvironmentVariables
2226

2327

2428
if TYPE_CHECKING:
@@ -30,10 +34,12 @@
3034
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
3135
TransportType = Union[AsyncHttpTransport, HttpTransport]
3236

37+
_LOGGER = logging.getLogger(__name__)
38+
3339
JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
3440

3541

36-
class AadClientBase(abc.ABC):
42+
class AadClientBase(abc.ABC): # pylint: disable=too-many-instance-attributes
3743
_POST = ["POST"]
3844

3945
def __init__(
@@ -45,10 +51,13 @@ def __init__(
4551
cae_cache: Optional[TokenCache] = None,
4652
*,
4753
additionally_allowed_tenants: Optional[List[str]] = None,
48-
**kwargs: Any
54+
**kwargs: Any,
4955
) -> None:
5056
self._authority = normalize_authority(authority) if authority else get_default_authority()
5157

58+
# False indicates uninitialized. Actual value is str or None.
59+
self._regional_authority: Optional[Union[str, bool]] = False
60+
5261
self._tenant_id = tenant_id
5362
self._client_id = client_id
5463
self._additionally_allowed_tenants = additionally_allowed_tenants or []
@@ -293,7 +302,7 @@ def _get_on_behalf_of_request(
293302
scopes: Iterable[str],
294303
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
295304
user_assertion: str,
296-
**kwargs: Any
305+
**kwargs: Any,
297306
) -> HttpRequest:
298307
data = {
299308
"assertion": user_assertion,
@@ -348,7 +357,7 @@ def _get_refresh_token_on_behalf_of_request(
348357
scopes: Iterable[str],
349358
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
350359
refresh_token: str,
351-
**kwargs: Any
360+
**kwargs: Any,
352361
) -> HttpRequest:
353362
data = {
354363
"grant_type": "refresh_token",
@@ -375,11 +384,43 @@ def _get_refresh_token_on_behalf_of_request(
375384
request = self._post(data, **kwargs)
376385
return request
377386

387+
def _get_region_discovery_request(self) -> HttpRequest:
388+
url = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=2021-01-01"
389+
request = HttpRequest("GET", url, headers={"Metadata": "true"})
390+
return request
391+
392+
def _process_region_discovery_response(self, response: PipelineResponse) -> Optional[str]:
393+
if response.http_response.status_code == 200:
394+
region = response.http_response.text().strip()
395+
if region:
396+
return region
397+
_LOGGER.warning("IMDS returned empty region")
398+
return None
399+
400+
def _build_regional_authority_url(self, regional_authority: str) -> Optional[str]:
401+
central_host = urlparse(self._authority).hostname
402+
if not central_host:
403+
return None
404+
405+
# This mirrors the regional authority logic in MSAL.
406+
if central_host in ("login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net"):
407+
regional_host = f"{regional_authority}.login.microsoft.com"
408+
else:
409+
regional_host = f"{regional_authority}.{central_host}"
410+
return f"https://{regional_host}"
411+
412+
def _get_regional_authority_from_env(self) -> Optional[str]:
413+
regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) or os.environ.get(
414+
"MSAL_FORCE_REGION"
415+
) # For parity with creds that rely on MSAL, we check this var too.
416+
return regional_authority.lower() if regional_authority else None
417+
378418
def _get_token_url(self, **kwargs: Any) -> str:
379419
tenant = resolve_tenant(
380420
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
381421
)
382-
return "/".join((self._authority, tenant, "oauth2/v2.0/token"))
422+
authority = cast(str, self._regional_authority) if self._regional_authority else self._authority
423+
return "/".join((authority, tenant, "oauth2/v2.0/token"))
383424

384425
def _post(self, data: Dict, **kwargs: Any) -> HttpRequest:
385426
url = self._get_token_url(**kwargs)

sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import logging
6+
import os
57
import time
68
from typing import Iterable, Optional, Union, Dict, Any
79

@@ -12,9 +14,12 @@
1214
from ..._internal import AadClientCertificate
1315
from ..._internal import AadClientBase
1416
from ..._internal.pipeline import build_async_pipeline
17+
from ..._enums import RegionalAuthority
1518

1619
Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]
1720

21+
_LOGGER = logging.getLogger(__name__)
22+
1823

1924
# pylint:disable=invalid-overridden-method
2025
class AadClient(AadClientBase):
@@ -33,6 +38,7 @@ async def close(self) -> None:
3338
async def obtain_token_by_authorization_code(
3439
self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs
3540
) -> AccessTokenInfo:
41+
await self._initialize_regional_authority()
3642
request = self._get_auth_code_request(
3743
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
3844
)
@@ -41,20 +47,24 @@ async def obtain_token_by_authorization_code(
4147
async def obtain_token_by_client_certificate(
4248
self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs
4349
) -> AccessTokenInfo:
50+
await self._initialize_regional_authority()
4451
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
4552
return await self._run_pipeline(request, stream=False, **kwargs)
4653

4754
async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo:
55+
await self._initialize_regional_authority()
4856
request = self._get_client_secret_request(scopes, secret, **kwargs)
4957
return await self._run_pipeline(request, **kwargs)
5058

5159
async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo:
60+
await self._initialize_regional_authority()
5261
request = self._get_jwt_assertion_request(scopes, assertion, **kwargs)
5362
return await self._run_pipeline(request, stream=False, **kwargs)
5463

5564
async def obtain_token_by_refresh_token(
5665
self, scopes: Iterable[str], refresh_token: str, **kwargs
5766
) -> AccessTokenInfo:
67+
await self._initialize_regional_authority()
5868
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
5969
return await self._run_pipeline(request, **kwargs)
6070

@@ -65,6 +75,7 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to
6575
refresh_token: str,
6676
**kwargs
6777
) -> AccessTokenInfo:
78+
await self._initialize_regional_authority()
6879
request = self._get_refresh_token_on_behalf_of_request(
6980
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
7081
)
@@ -77,6 +88,7 @@ async def obtain_token_on_behalf_of(
7788
user_assertion: str,
7889
**kwargs
7990
) -> AccessTokenInfo:
91+
await self._initialize_regional_authority()
8092
request = self._get_on_behalf_of_request(
8193
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
8294
)
@@ -85,6 +97,38 @@ async def obtain_token_on_behalf_of(
8597
def _build_pipeline(self, **kwargs) -> AsyncPipeline:
8698
return build_async_pipeline(**kwargs)
8799

100+
async def _initialize_regional_authority(self) -> None:
101+
# This is based on MSAL's regional authority logic.
102+
if self._regional_authority is not False:
103+
return
104+
105+
regional_authority = self._get_regional_authority_from_env()
106+
if not regional_authority:
107+
self._regional_authority = None
108+
return
109+
110+
if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION:
111+
# Attempt to discover the region from IMDS
112+
regional_authority = await self._discover_region()
113+
if not regional_authority:
114+
_LOGGER.debug("Failed to auto-discover region. Using the non-regional authority.")
115+
self._regional_authority = None
116+
return
117+
118+
self._regional_authority = self._build_regional_authority_url(regional_authority)
119+
120+
async def _discover_region(self) -> Optional[str]:
121+
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
122+
if region:
123+
return region
124+
try:
125+
request = self._get_region_discovery_request()
126+
response = await self._pipeline.run(request)
127+
return self._process_region_discovery_response(response)
128+
except Exception as ex: # pylint: disable=broad-except
129+
_LOGGER.info("Failed to discover Azure region from IMDS: %s", ex)
130+
return None
131+
88132
async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo:
89133
# remove tenant_id and claims kwarg that could have been passed from credential's get_token method
90134
# tenant_id is already part of `request` at this point

sdk/identity/azure-identity/tests/test_aad_client.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import functools
6+
from unittest import mock
67
from unittest.mock import Mock, patch
78

89
from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
@@ -113,6 +114,80 @@ def send(request, **_):
113114
client.obtain_token_by_refresh_token("scope", "refresh token")
114115

115116

117+
def test_token_url():
118+
tenant_id = "tenant-id"
119+
client = AadClient(tenant_id, "client-id", authority="https://login.microsoftonline.com")
120+
assert client._get_token_url() == "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token"
121+
122+
# Test with usage of AZURE_AUTHORITY_HOST
123+
with patch.dict(
124+
"os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.microsoftonline.com"}, clear=True
125+
):
126+
client = AadClient(tenant_id=tenant_id, client_id="client-id")
127+
assert client._get_token_url() == "https://custom.microsoftonline.com/tenant-id/oauth2/v2.0/token"
128+
129+
# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME
130+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True):
131+
client = AadClient(tenant_id=tenant_id, client_id="client-id")
132+
assert client._get_token_url() == "https://centralus.login.microsoft.com/tenant-id/oauth2/v2.0/token"
133+
134+
# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME and AZURE_AUTHORITY_HOST
135+
with patch.dict(
136+
"os.environ",
137+
{
138+
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://login.microsoftonline.us",
139+
EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus",
140+
},
141+
clear=True,
142+
):
143+
client = AadClient(tenant_id=tenant_id, client_id="client-id")
144+
assert client._get_token_url() == "https://centralus.login.microsoftonline.us/tenant-id/oauth2/v2.0/token"
145+
146+
147+
def test_initialize_regional_authority():
148+
client = AadClient("tenant-id", "client-id")
149+
assert client._regional_authority is None
150+
151+
# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME
152+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True):
153+
client = AadClient("tenant-id", "client-id")
154+
assert client._regional_authority == "https://centralus.login.microsoft.com"
155+
156+
# Test with non-Microsoft authority host
157+
with patch.dict(
158+
"os.environ",
159+
{
160+
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com",
161+
EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus",
162+
},
163+
clear=True,
164+
):
165+
client = AadClient("tenant-id", "client-id")
166+
assert client._regional_authority == "https://centralus.custom.authority.com"
167+
168+
# Test with usage of region auto-discovery env var
169+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True):
170+
with patch.dict("os.environ", {"REGION_NAME": "eastus"}):
171+
client = AadClient("tenant-id", "client-id")
172+
assert client._regional_authority == "https://eastus.login.microsoft.com"
173+
174+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True):
175+
response = Mock(
176+
status_code=200,
177+
headers={"Content-Type": "text/plain"},
178+
content_type="text/plain",
179+
text=lambda encoding=None: "westus2",
180+
)
181+
transport = mock.Mock(send=mock.Mock(return_value=response))
182+
183+
client = AadClient("tenant-id", "client-id", transport=transport)
184+
assert client._regional_authority == "https://westus2.login.microsoft.com"
185+
186+
with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True):
187+
client = AadClient("tenant-id", "client-id")
188+
assert client._regional_authority == "https://westus3.login.microsoft.com"
189+
190+
116191
@pytest.mark.parametrize("secret", (None, "client secret"))
117192
def test_authorization_code(secret):
118193
tenant_id = "tenant-id"

0 commit comments

Comments
 (0)