From 716174b5ac7f9b193dfcb7e29c80ae792c7c0c2f Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 9 Dec 2025 19:26:55 +0000 Subject: [PATCH 1/3] [Identity] Expand usage of regional authorities Currently, only MSAL confidential client based credentials allow passing in a region through the `AZURE_REGIONAL_AUTHORITY_NAME`. For the other non-MSAL based confidential flows, we should have parity with the MSAL ones. Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 2 + .../azure/identity/_internal/aad_client.py | 45 +++++++- .../identity/_internal/aad_client_base.py | 51 ++++++++- .../identity/aio/_internal/aad_client.py | 44 ++++++++ .../azure-identity/tests/test_aad_client.py | 75 +++++++++++++ .../tests/test_aad_client_async.py | 101 +++++++++++++++++- 6 files changed, 311 insertions(+), 7 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 442e17043e38..061a653e837e 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -12,6 +12,8 @@ ### Bugs Fixed +- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347)) + ### Other Changes ## 1.26.0b1 (2025-11-07) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index 02fb0c922f5e..8c01535847e3 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # ------------------------------------ import time +import logging +import os from typing import Iterable, Union, Optional, Any from azure.core.credentials import AccessTokenInfo @@ -11,9 +13,19 @@ from .aad_client_base import AadClientBase from .aadclient_certificate import AadClientCertificate from .pipeline import build_pipeline +from .._enums import RegionalAuthority -class AadClient(AadClientBase): +_LOGGER = logging.getLogger(__name__) + + +class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-keyword + + # pylint:disable=missing-client-constructor-parameter-credential + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._initialize_regional_authority() + def __enter__(self) -> "AadClient": self._pipeline.__enter__() return self @@ -62,6 +74,37 @@ def obtain_token_on_behalf_of( # no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL raise NotImplementedError() + def _initialize_regional_authority(self) -> None: + # This is based on MSAL's regional authority logic. + if self._regional_authority is not False: + return + + regional_authority = self._get_regional_authority_from_env() + if not regional_authority: + self._regional_authority = None + return + + if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION: + regional_authority = self._discover_region() + if not regional_authority: + _LOGGER.warning("Failed to auto-discover region. Using the non-regional authority.") + self._regional_authority = None + return + + self._regional_authority = self._build_regional_authority_url(regional_authority) + + def _discover_region(self) -> Optional[str]: + region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() + if region: + return region + try: + request = self._get_region_discovery_request() + response = self._pipeline.run(request) + return self._process_region_discovery_response(response) + except Exception as ex: # pylint: disable=broad-except + _LOGGER.info("Failed to discover Azure region from IMDS: %s", ex) + return None + def _build_pipeline(self, **kwargs: Any) -> Pipeline: return build_pipeline(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index a6369a1e208e..876ba56a5c7e 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -5,8 +5,11 @@ import abc import base64 import json +import logging +import os import time from uuid import uuid4 +from urllib.parse import urlparse from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast from msal import TokenCache @@ -19,6 +22,7 @@ from .utils import get_default_authority, normalize_authority, resolve_tenant from .aadclient_certificate import AadClientCertificate from .._persistent_cache import _load_persistent_cache +from .._constants import EnvironmentVariables if TYPE_CHECKING: @@ -30,10 +34,12 @@ PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy] TransportType = Union[AsyncHttpTransport, HttpTransport] +_LOGGER = logging.getLogger(__name__) + JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" -class AadClientBase(abc.ABC): +class AadClientBase(abc.ABC): # pylint: disable=too-many-instance-attributes _POST = ["POST"] def __init__( @@ -45,10 +51,13 @@ def __init__( cae_cache: Optional[TokenCache] = None, *, additionally_allowed_tenants: Optional[List[str]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: self._authority = normalize_authority(authority) if authority else get_default_authority() + # False indicates uninitialized. Actual value is str or None. + self._regional_authority: Optional[Union[str, bool]] = False + self._tenant_id = tenant_id self._client_id = client_id self._additionally_allowed_tenants = additionally_allowed_tenants or [] @@ -293,7 +302,7 @@ def _get_on_behalf_of_request( scopes: Iterable[str], client_credential: Union[str, AadClientCertificate, Dict[str, Any]], user_assertion: str, - **kwargs: Any + **kwargs: Any, ) -> HttpRequest: data = { "assertion": user_assertion, @@ -348,7 +357,7 @@ def _get_refresh_token_on_behalf_of_request( scopes: Iterable[str], client_credential: Union[str, AadClientCertificate, Dict[str, Any]], refresh_token: str, - **kwargs: Any + **kwargs: Any, ) -> HttpRequest: data = { "grant_type": "refresh_token", @@ -375,11 +384,43 @@ def _get_refresh_token_on_behalf_of_request( request = self._post(data, **kwargs) return request + def _get_region_discovery_request(self) -> HttpRequest: + url = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=2021-01-01" + request = HttpRequest("GET", url, headers={"Metadata": "true"}) + return request + + def _process_region_discovery_response(self, response: PipelineResponse) -> Optional[str]: + if response.http_response.status_code == 200: + region = response.http_response.text().strip() + if region: + return region + _LOGGER.warning("IMDS returned empty region") + return None + + def _build_regional_authority_url(self, regional_authority: str) -> Optional[str]: + central_host = urlparse(self._authority).hostname + if not central_host: + return None + + # This mirrors the regional authority logic in MSAL. + if central_host in ("login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net"): + regional_host = f"{regional_authority}.login.microsoft.com" + else: + regional_host = f"{regional_authority}.{central_host}" + return f"https://{regional_host}" + + def _get_regional_authority_from_env(self) -> Optional[str]: + regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) or os.environ.get( + "MSAL_FORCE_REGION" + ) # For parity with creds that rely on MSAL, we check this var too. + return regional_authority.lower() if regional_authority else None + def _get_token_url(self, **kwargs: Any) -> str: tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs ) - return "/".join((self._authority, tenant, "oauth2/v2.0/token")) + authority = cast(str, self._regional_authority) if self._regional_authority else self._authority + return "/".join((authority, tenant, "oauth2/v2.0/token")) def _post(self, data: Dict, **kwargs: Any) -> HttpRequest: url = self._get_token_url(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 7b99f85ac912..2819b1d5f456 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import logging +import os import time from typing import Iterable, Optional, Union, Dict, Any @@ -12,9 +14,12 @@ from ..._internal import AadClientCertificate from ..._internal import AadClientBase from ..._internal.pipeline import build_async_pipeline +from ..._enums import RegionalAuthority Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy] +_LOGGER = logging.getLogger(__name__) + # pylint:disable=invalid-overridden-method class AadClient(AadClientBase): @@ -33,6 +38,7 @@ async def close(self) -> None: async def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs ) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -41,20 +47,24 @@ async def obtain_token_by_authorization_code( async def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs ) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_client_certificate_request(scopes, certificate, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_client_secret_request(scopes, secret, **kwargs) return await self._run_pipeline(request, **kwargs) async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return await self._run_pipeline(request, stream=False, **kwargs) async def obtain_token_by_refresh_token( self, scopes: Iterable[str], refresh_token: str, **kwargs ) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return await self._run_pipeline(request, **kwargs) @@ -65,6 +75,7 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to refresh_token: str, **kwargs ) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_refresh_token_on_behalf_of_request( scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs ) @@ -77,6 +88,7 @@ async def obtain_token_on_behalf_of( user_assertion: str, **kwargs ) -> AccessTokenInfo: + await self._initialize_regional_authority() request = self._get_on_behalf_of_request( scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs ) @@ -85,6 +97,38 @@ async def obtain_token_on_behalf_of( def _build_pipeline(self, **kwargs) -> AsyncPipeline: return build_async_pipeline(**kwargs) + async def _initialize_regional_authority(self) -> None: + # This is based on MSAL's regional authority logic. + if self._regional_authority is not False: + return + + regional_authority = self._get_regional_authority_from_env() + if not regional_authority: + self._regional_authority = None + return + + if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION: + # Attempt to discover the region from IMDS + regional_authority = await self._discover_region() + if not regional_authority: + _LOGGER.debug("Failed to auto-discover region. Using the non-regional authority.") + self._regional_authority = None + return + + self._regional_authority = self._build_regional_authority_url(regional_authority) + + async def _discover_region(self) -> Optional[str]: + region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() + if region: + return region + try: + request = self._get_region_discovery_request() + response = await self._pipeline.run(request) + return self._process_region_discovery_response(response) + except Exception as ex: # pylint: disable=broad-except + _LOGGER.info("Failed to discover Azure region from IMDS: %s", ex) + return None + async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index b9f8c374d341..78f058667985 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import functools +from unittest import mock from unittest.mock import Mock, patch from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError @@ -113,6 +114,80 @@ def send(request, **_): client.obtain_token_by_refresh_token("scope", "refresh token") +def test_token_url(): + tenant_id = "tenant-id" + client = AadClient(tenant_id, "client-id", authority="https://login.microsoftonline.com") + assert client._get_token_url() == "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token" + + # Test with usage of AZURE_AUTHORITY_HOST + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.microsoftonline.com"}, clear=True + ): + client = AadClient(tenant_id=tenant_id, client_id="client-id") + assert client._get_token_url() == "https://custom.microsoftonline.com/tenant-id/oauth2/v2.0/token" + + # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): + client = AadClient(tenant_id=tenant_id, client_id="client-id") + assert client._get_token_url() == "https://centralus.login.microsoft.com/tenant-id/oauth2/v2.0/token" + + # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME and AZURE_AUTHORITY_HOST + with patch.dict( + "os.environ", + { + EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://login.microsoftonline.us", + EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", + }, + clear=True, + ): + client = AadClient(tenant_id=tenant_id, client_id="client-id") + assert client._get_token_url() == "https://centralus.login.microsoftonline.us/tenant-id/oauth2/v2.0/token" + + +def test_initialize_regional_authority(): + client = AadClient("tenant-id", "client-id") + assert client._regional_authority is None + + # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): + client = AadClient("tenant-id", "client-id") + assert client._regional_authority == "https://centralus.login.microsoft.com" + + # Test with non-Microsoft authority host + with patch.dict( + "os.environ", + { + EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com", + EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", + }, + clear=True, + ): + client = AadClient("tenant-id", "client-id") + assert client._regional_authority == "https://centralus.custom.authority.com" + + # Test with usage of region auto-discovery env var + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): + with patch.dict("os.environ", {"REGION_NAME": "eastus"}): + client = AadClient("tenant-id", "client-id") + assert client._regional_authority == "https://eastus.login.microsoft.com" + + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): + response = Mock( + status_code=200, + headers={"Content-Type": "text/plain"}, + content_type="text/plain", + text=lambda encoding=None: "westus2", + ) + transport = mock.Mock(send=mock.Mock(return_value=response)) + + client = AadClient("tenant-id", "client-id", transport=transport) + assert client._regional_authority == "https://westus2.login.microsoft.com" + + with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True): + client = AadClient("tenant-id", "client-id") + assert client._regional_authority == "https://westus3.login.microsoft.com" + + @pytest.mark.parametrize("secret", (None, "client secret")) def test_authorization_code(secret): tenant_id = "tenant-id" diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 56a2f1486a8c..b0a2586f2fed 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -14,7 +14,7 @@ import pytest from helpers import build_aad_response, mock_response -from helpers_async import get_completed_future +from helpers_async import get_completed_future, AsyncMockTransport from test_certificate_credential import PEM_CERT_PATH pytestmark = pytest.mark.asyncio @@ -189,6 +189,45 @@ async def send(request, **_): await client.obtain_token_by_refresh_token("scope", "refresh token") +async def test_request_url_with_regional_authority(): + + async def send(request, **_): + assert urlparse(request.url).netloc == "centralus.login.microsoft.com" + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): + client = AadClient("tenant-id", "client id", transport=Mock(send=send)) + + await client.obtain_token_by_authorization_code("scope", "code", "uri") + await client.obtain_token_by_refresh_token("scope", "refresh token") + + # obtain_token_by_refresh_token is client_secret safe + await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") + + +async def test_regional_authority_initialized_once(): + """The client should lazily initialize its regional authority only once.""" + + async def send(request, **_): + assert urlparse(request.url).netloc == "centralus.login.microsoft.com" + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + # Mock _get_regional_authority_from_env to track how many times it's called. + with patch("azure.identity.aio._internal.aad_client.AadClient._get_regional_authority_from_env") as mock_env: + mock_env.return_value = "centralus" + transport = AsyncMockTransport(send=Mock(wraps=send)) + client = AadClient("tenant-id", "client id", transport=transport) + + # The first token request should trigger initialization. + await client.obtain_token_by_authorization_code("scope", "code", "uri") + # Subsequent requests shouldn't. + await client.obtain_token_by_refresh_token("scope", "refresh token") + await client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") + + # Env should be checked only once. + assert mock_env.call_count == 1 + + async def test_evicts_invalid_refresh_token(): """when Microsoft Entra ID rejects a refresh token, the client should evict that token from its cache""" @@ -324,3 +363,63 @@ async def test_multitenant_cache(): assert client_d.get_cached_access_token([scope]) is None with pytest.raises(ClientAuthenticationError, match=message): client_d.get_cached_access_token([scope], tenant_id=tenant_a) + + +async def test_initialize_regional_authority(): + client = AadClient("tenant-id", "client-id") + async with client: + await client._initialize_regional_authority() + assert client._regional_authority is None + + # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): + client = AadClient("tenant-id", "client-id") + async with client: + await client._initialize_regional_authority() + assert client._regional_authority == "https://centralus.login.microsoft.com" + + # Test with non-Microsoft authority host + with patch.dict( + "os.environ", + { + EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com", + EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", + }, + clear=True, + ): + client = AadClient("tenant-id", "client-id") + async with client: + await client._initialize_regional_authority() + assert client._regional_authority == "https://centralus.custom.authority.com" + + # Test with usage of region auto-discovery env var + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): + with patch.dict("os.environ", {"REGION_NAME": "eastus"}): + client = AadClient("tenant-id", "client-id") + await client._initialize_regional_authority() + assert client._regional_authority == "https://eastus.login.microsoft.com" + await client.close() + + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): + response = Mock( + status_code=200, + headers={"Content-Type": "text/plain"}, + content_type="text/plain", + text=lambda encoding=None: "westus2", + ) + + async def send(*args, **kwargs): + return response + + transport = AsyncMockTransport(send=Mock(wraps=send)) + + client = AadClient("tenant-id", "client-id", transport=transport) + async with client: + await client._initialize_regional_authority() + assert client._regional_authority == "https://westus2.login.microsoft.com" + + with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True): + client = AadClient("tenant-id", "client-id") + async with client: + await client._initialize_regional_authority() + assert client._regional_authority == "https://westus3.login.microsoft.com" From 5226ddbb9c8c3528ab76cb05832179898589ddcb Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Tue, 9 Dec 2025 20:33:05 +0000 Subject: [PATCH 2/3] Updates Signed-off-by: Paul Van Eck --- .../azure/identity/_internal/aad_client.py | 10 ++- .../identity/aio/_internal/aad_client.py | 4 +- .../azure-identity/tests/test_aad_client.py | 65 ++++++++++++------- .../tests/test_aad_client_async.py | 4 +- 4 files changed, 53 insertions(+), 30 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index 8c01535847e3..dd327f2695fc 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -24,7 +24,6 @@ class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-key # pylint:disable=missing-client-constructor-parameter-credential def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._initialize_regional_authority() def __enter__(self) -> "AadClient": self._pipeline.__enter__() @@ -39,6 +38,7 @@ def close(self) -> None: def obtain_token_by_authorization_code( self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs: Any ) -> AccessTokenInfo: + self._initialize_regional_authority() request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs ) @@ -47,20 +47,24 @@ def obtain_token_by_authorization_code( def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs: Any ) -> AccessTokenInfo: + self._initialize_regional_authority() request = self._get_client_certificate_request(scopes, certificate, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessTokenInfo: + self._initialize_regional_authority() request = self._get_client_secret_request(scopes, secret, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessTokenInfo: + self._initialize_regional_authority() request = self._get_jwt_assertion_request(scopes, assertion, **kwargs) return self._run_pipeline(request, **kwargs) def obtain_token_by_refresh_token( self, scopes: Iterable[str], refresh_token: str, **kwargs: Any ) -> AccessTokenInfo: + self._initialize_regional_authority() request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return self._run_pipeline(request, **kwargs) @@ -87,7 +91,7 @@ def _initialize_regional_authority(self) -> None: if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION: regional_authority = self._discover_region() if not regional_authority: - _LOGGER.warning("Failed to auto-discover region. Using the non-regional authority.") + _LOGGER.info("Failed to auto-discover region. Using the non-regional authority.") self._regional_authority = None return @@ -102,7 +106,7 @@ def _discover_region(self) -> Optional[str]: response = self._pipeline.run(request) return self._process_region_discovery_response(response) except Exception as ex: # pylint: disable=broad-except - _LOGGER.info("Failed to discover Azure region from IMDS: %s", ex) + _LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex) return None def _build_pipeline(self, **kwargs: Any) -> Pipeline: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 2819b1d5f456..686d6ae71e17 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -111,7 +111,7 @@ async def _initialize_regional_authority(self) -> None: # Attempt to discover the region from IMDS regional_authority = await self._discover_region() if not regional_authority: - _LOGGER.debug("Failed to auto-discover region. Using the non-regional authority.") + _LOGGER.info("Failed to auto-discover region. Using the non-regional authority.") self._regional_authority = None return @@ -126,7 +126,7 @@ async def _discover_region(self) -> Optional[str]: response = await self._pipeline.run(request) return self._process_region_discovery_response(response) except Exception as ex: # pylint: disable=broad-except - _LOGGER.info("Failed to discover Azure region from IMDS: %s", ex) + _LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex) return None async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo: diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 78f058667985..24f9e5de9508 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -114,43 +114,56 @@ def send(request, **_): client.obtain_token_by_refresh_token("scope", "refresh token") -def test_token_url(): - tenant_id = "tenant-id" - client = AadClient(tenant_id, "client-id", authority="https://login.microsoftonline.com") - assert client._get_token_url() == "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token" +def test_request_url_with_regional_authority(): - # Test with usage of AZURE_AUTHORITY_HOST - with patch.dict( - "os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.microsoftonline.com"}, clear=True - ): - client = AadClient(tenant_id=tenant_id, client_id="client-id") - assert client._get_token_url() == "https://custom.microsoftonline.com/tenant-id/oauth2/v2.0/token" + def send(request, **_): + assert urlparse(request.url).netloc == "centralus.login.microsoft.com" + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): - client = AadClient(tenant_id=tenant_id, client_id="client-id") - assert client._get_token_url() == "https://centralus.login.microsoft.com/tenant-id/oauth2/v2.0/token" + client = AadClient("tenant-id", "client id", transport=Mock(send=send)) + + client.obtain_token_by_authorization_code("scope", "code", "uri") + client.obtain_token_by_refresh_token("scope", "refresh token") + + # obtain_token_by_refresh_token is client_secret safe + client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") - # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME and AZURE_AUTHORITY_HOST - with patch.dict( - "os.environ", - { - EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://login.microsoftonline.us", - EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus", - }, - clear=True, - ): - client = AadClient(tenant_id=tenant_id, client_id="client-id") - assert client._get_token_url() == "https://centralus.login.microsoftonline.us/tenant-id/oauth2/v2.0/token" + +def test_regional_authority_initialized_once(): + """The client should lazily initialize its regional authority only once.""" + + def send(request, **_): + assert urlparse(request.url).netloc == "centralus.login.microsoft.com" + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + with patch("azure.identity._internal.aad_client.AadClient._get_regional_authority_from_env") as mock_env: + mock_env.return_value = "centralus" + transport = Mock(send=Mock(wraps=send)) + client = AadClient("tenant-id", "client id", transport=transport) + + # The first token request should trigger initialization. + client.obtain_token_by_authorization_code("scope", "code", "uri") + # Subsequent requests shouldn't. + client.obtain_token_by_refresh_token("scope", "refresh token") + client.obtain_token_by_refresh_token("scope", "refresh token", client_secret="secret") + + # Env should be checked only once. + assert mock_env.call_count == 1 def test_initialize_regional_authority(): client = AadClient("tenant-id", "client-id") + # The initial state should be False (uninitialized) + assert client._regional_authority is False + + client._initialize_regional_authority() assert client._regional_authority is None # Test with usage of AZURE_REGIONAL_AUTHORITY_NAME with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True): client = AadClient("tenant-id", "client-id") + client._initialize_regional_authority() assert client._regional_authority == "https://centralus.login.microsoft.com" # Test with non-Microsoft authority host @@ -163,12 +176,14 @@ def test_initialize_regional_authority(): clear=True, ): client = AadClient("tenant-id", "client-id") + client._initialize_regional_authority() assert client._regional_authority == "https://centralus.custom.authority.com" # Test with usage of region auto-discovery env var with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): with patch.dict("os.environ", {"REGION_NAME": "eastus"}): client = AadClient("tenant-id", "client-id") + client._initialize_regional_authority() assert client._regional_authority == "https://eastus.login.microsoft.com" with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): @@ -181,10 +196,12 @@ def test_initialize_regional_authority(): transport = mock.Mock(send=mock.Mock(return_value=response)) client = AadClient("tenant-id", "client-id", transport=transport) + client._initialize_regional_authority() assert client._regional_authority == "https://westus2.login.microsoft.com" with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True): client = AadClient("tenant-id", "client-id") + client._initialize_regional_authority() assert client._regional_authority == "https://westus3.login.microsoft.com" diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index b0a2586f2fed..365e701e30d2 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -212,7 +212,6 @@ async def send(request, **_): assert urlparse(request.url).netloc == "centralus.login.microsoft.com" return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - # Mock _get_regional_authority_from_env to track how many times it's called. with patch("azure.identity.aio._internal.aad_client.AadClient._get_regional_authority_from_env") as mock_env: mock_env.return_value = "centralus" transport = AsyncMockTransport(send=Mock(wraps=send)) @@ -367,6 +366,9 @@ async def test_multitenant_cache(): async def test_initialize_regional_authority(): client = AadClient("tenant-id", "client-id") + # The initial state should be False (uninitialized) + assert client._regional_authority is False + async with client: await client._initialize_regional_authority() assert client._regional_authority is None From 8a468858b374f7a3002821df0253ed3efe52052a Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Wed, 10 Dec 2025 01:16:27 +0000 Subject: [PATCH 3/3] Add support for True Signed-off-by: Paul Van Eck --- .../azure-identity/azure/identity/_internal/aad_client.py | 2 +- .../azure/identity/aio/_internal/aad_client.py | 2 +- sdk/identity/azure-identity/tests/test_aad_client.py | 8 ++++++++ .../azure-identity/tests/test_aad_client_async.py | 8 ++++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index dd327f2695fc..c6690033697d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -88,7 +88,7 @@ def _initialize_regional_authority(self) -> None: self._regional_authority = None return - if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION: + if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]: regional_authority = self._discover_region() if not regional_authority: _LOGGER.info("Failed to auto-discover region. Using the non-regional authority.") diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 686d6ae71e17..f8f434dfe3fd 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -107,7 +107,7 @@ async def _initialize_regional_authority(self) -> None: self._regional_authority = None return - if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION: + if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]: # Attempt to discover the region from IMDS regional_authority = await self._discover_region() if not regional_authority: diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 24f9e5de9508..6facdafb4629 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -179,6 +179,14 @@ def test_initialize_regional_authority(): client._initialize_regional_authority() assert client._regional_authority == "https://centralus.custom.authority.com" + # Test with usage of region auto-discovery env var + # Test with AZURE_REGIONAL_AUTHORITY_NAME set to "True" (auto-discovery) + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "True"}, clear=True): + with patch.dict("os.environ", {"REGION_NAME": "southcentralus"}): + client = AadClient("tenant-id", "client-id") + client._initialize_regional_authority() + assert client._regional_authority == "https://southcentralus.login.microsoft.com" + # Test with usage of region auto-discovery env var with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): with patch.dict("os.environ", {"REGION_NAME": "eastus"}): diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 365e701e30d2..d10357e545df 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -394,6 +394,14 @@ async def test_initialize_regional_authority(): await client._initialize_regional_authority() assert client._regional_authority == "https://centralus.custom.authority.com" + # Test with AZURE_REGIONAL_AUTHORITY_NAME set to "True" (auto-discovery) + with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "True"}, clear=True): + with patch.dict("os.environ", {"REGION_NAME": "southcentralus"}): + client = AadClient("tenant-id", "client-id") + await client._initialize_regional_authority() + assert client._regional_authority == "https://southcentralus.login.microsoft.com" + await client.close() + # Test with usage of region auto-discovery env var with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True): with patch.dict("os.environ", {"REGION_NAME": "eastus"}):