Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### Bugs Fixed

- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))

### Other Changes

## 1.26.0b1 (2025-11-07)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Licensed under the MIT License.
# ------------------------------------
import time
import logging
import os
from typing import Iterable, Union, Optional, Any

from azure.core.credentials import AccessTokenInfo
Expand All @@ -11,9 +13,19 @@
from .aad_client_base import AadClientBase
from .aadclient_certificate import AadClientCertificate
from .pipeline import build_pipeline
from .._enums import RegionalAuthority


class AadClient(AadClientBase):
_LOGGER = logging.getLogger(__name__)


class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-keyword

# pylint:disable=missing-client-constructor-parameter-credential
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._initialize_regional_authority()

def __enter__(self) -> "AadClient":
self._pipeline.__enter__()
return self
Expand Down Expand Up @@ -62,6 +74,37 @@ def obtain_token_on_behalf_of(
# no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL
raise NotImplementedError()

def _initialize_regional_authority(self) -> None:
# This is based on MSAL's regional authority logic.
if self._regional_authority is not False:
return

regional_authority = self._get_regional_authority_from_env()
if not regional_authority:
self._regional_authority = None
return

if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION:
regional_authority = self._discover_region()
if not regional_authority:
_LOGGER.warning("Failed to auto-discover region. Using the non-regional authority.")
self._regional_authority = None
return

self._regional_authority = self._build_regional_authority_url(regional_authority)

def _discover_region(self) -> Optional[str]:
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
if region:
return region
try:
request = self._get_region_discovery_request()
response = self._pipeline.run(request)
return self._process_region_discovery_response(response)
except Exception as ex: # pylint: disable=broad-except
_LOGGER.info("Failed to discover Azure region from IMDS: %s", ex)
return None

def _build_pipeline(self, **kwargs: Any) -> Pipeline:
return build_pipeline(**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import abc
import base64
import json
import logging
import os
import time
from uuid import uuid4
from urllib.parse import urlparse
from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast

from msal import TokenCache
Expand All @@ -19,6 +22,7 @@
from .utils import get_default_authority, normalize_authority, resolve_tenant
from .aadclient_certificate import AadClientCertificate
from .._persistent_cache import _load_persistent_cache
from .._constants import EnvironmentVariables


if TYPE_CHECKING:
Expand All @@ -30,10 +34,12 @@
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

_LOGGER = logging.getLogger(__name__)

JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"


class AadClientBase(abc.ABC):
class AadClientBase(abc.ABC): # pylint: disable=too-many-instance-attributes
_POST = ["POST"]

def __init__(
Expand All @@ -45,10 +51,13 @@ def __init__(
cae_cache: Optional[TokenCache] = None,
*,
additionally_allowed_tenants: Optional[List[str]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self._authority = normalize_authority(authority) if authority else get_default_authority()

# False indicates uninitialized. Actual value is str or None.
self._regional_authority: Optional[Union[str, bool]] = False

self._tenant_id = tenant_id
self._client_id = client_id
self._additionally_allowed_tenants = additionally_allowed_tenants or []
Expand Down Expand Up @@ -293,7 +302,7 @@ def _get_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"assertion": user_assertion,
Expand Down Expand Up @@ -348,7 +357,7 @@ def _get_refresh_token_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"grant_type": "refresh_token",
Expand All @@ -375,11 +384,43 @@ def _get_refresh_token_on_behalf_of_request(
request = self._post(data, **kwargs)
return request

def _get_region_discovery_request(self) -> HttpRequest:
url = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=2021-01-01"
request = HttpRequest("GET", url, headers={"Metadata": "true"})
return request

def _process_region_discovery_response(self, response: PipelineResponse) -> Optional[str]:
if response.http_response.status_code == 200:
region = response.http_response.text().strip()
if region:
return region
_LOGGER.warning("IMDS returned empty region")
return None

def _build_regional_authority_url(self, regional_authority: str) -> Optional[str]:
central_host = urlparse(self._authority).hostname
if not central_host:
return None

# This mirrors the regional authority logic in MSAL.
if central_host in ("login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net"):
regional_host = f"{regional_authority}.login.microsoft.com"
else:
regional_host = f"{regional_authority}.{central_host}"
return f"https://{regional_host}"

def _get_regional_authority_from_env(self) -> Optional[str]:
regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) or os.environ.get(
"MSAL_FORCE_REGION"
) # For parity with creds that rely on MSAL, we check this var too.
return regional_authority.lower() if regional_authority else None

def _get_token_url(self, **kwargs: Any) -> str:
tenant = resolve_tenant(
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
)
return "/".join((self._authority, tenant, "oauth2/v2.0/token"))
authority = cast(str, self._regional_authority) if self._regional_authority else self._authority
return "/".join((authority, tenant, "oauth2/v2.0/token"))

def _post(self, data: Dict, **kwargs: Any) -> HttpRequest:
url = self._get_token_url(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
import time
from typing import Iterable, Optional, Union, Dict, Any

Expand All @@ -12,9 +14,12 @@
from ..._internal import AadClientCertificate
from ..._internal import AadClientBase
from ..._internal.pipeline import build_async_pipeline
from ..._enums import RegionalAuthority

Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]

_LOGGER = logging.getLogger(__name__)


# pylint:disable=invalid-overridden-method
class AadClient(AadClientBase):
Expand All @@ -33,6 +38,7 @@ async def close(self) -> None:
async def obtain_token_by_authorization_code(
self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
Expand All @@ -41,20 +47,24 @@ async def obtain_token_by_authorization_code(
async def obtain_token_by_client_certificate(
self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_client_secret_request(scopes, secret, **kwargs)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_jwt_assertion_request(scopes, assertion, **kwargs)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_refresh_token(
self, scopes: Iterable[str], refresh_token: str, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
return await self._run_pipeline(request, **kwargs)

Expand All @@ -65,6 +75,7 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to
refresh_token: str,
**kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_refresh_token_on_behalf_of_request(
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
)
Expand All @@ -77,6 +88,7 @@ async def obtain_token_on_behalf_of(
user_assertion: str,
**kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_on_behalf_of_request(
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
)
Expand All @@ -85,6 +97,38 @@ async def obtain_token_on_behalf_of(
def _build_pipeline(self, **kwargs) -> AsyncPipeline:
return build_async_pipeline(**kwargs)

async def _initialize_regional_authority(self) -> None:
# This is based on MSAL's regional authority logic.
if self._regional_authority is not False:
return

regional_authority = self._get_regional_authority_from_env()
if not regional_authority:
self._regional_authority = None
return

if regional_authority == RegionalAuthority.AUTO_DISCOVER_REGION:
# Attempt to discover the region from IMDS
regional_authority = await self._discover_region()
if not regional_authority:
_LOGGER.debug("Failed to auto-discover region. Using the non-regional authority.")
self._regional_authority = None
return

self._regional_authority = self._build_regional_authority_url(regional_authority)

async def _discover_region(self) -> Optional[str]:
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
if region:
return region
try:
request = self._get_region_discovery_request()
response = await self._pipeline.run(request)
return self._process_region_discovery_response(response)
except Exception as ex: # pylint: disable=broad-except
_LOGGER.info("Failed to discover Azure region from IMDS: %s", ex)
return None

async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo:
# remove tenant_id and claims kwarg that could have been passed from credential's get_token method
# tenant_id is already part of `request` at this point
Expand Down
75 changes: 75 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import functools
from unittest import mock
from unittest.mock import Mock, patch

from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
Expand Down Expand Up @@ -113,6 +114,80 @@ def send(request, **_):
client.obtain_token_by_refresh_token("scope", "refresh token")


def test_token_url():
tenant_id = "tenant-id"
client = AadClient(tenant_id, "client-id", authority="https://login.microsoftonline.com")
assert client._get_token_url() == "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token"

# Test with usage of AZURE_AUTHORITY_HOST
with patch.dict(
"os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.microsoftonline.com"}, clear=True
):
client = AadClient(tenant_id=tenant_id, client_id="client-id")
assert client._get_token_url() == "https://custom.microsoftonline.com/tenant-id/oauth2/v2.0/token"

# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True):
client = AadClient(tenant_id=tenant_id, client_id="client-id")
assert client._get_token_url() == "https://centralus.login.microsoft.com/tenant-id/oauth2/v2.0/token"

# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME and AZURE_AUTHORITY_HOST
with patch.dict(
"os.environ",
{
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://login.microsoftonline.us",
EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus",
},
clear=True,
):
client = AadClient(tenant_id=tenant_id, client_id="client-id")
assert client._get_token_url() == "https://centralus.login.microsoftonline.us/tenant-id/oauth2/v2.0/token"


def test_initialize_regional_authority():
client = AadClient("tenant-id", "client-id")
assert client._regional_authority is None

# Test with usage of AZURE_REGIONAL_AUTHORITY_NAME
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus"}, clear=True):
client = AadClient("tenant-id", "client-id")
assert client._regional_authority == "https://centralus.login.microsoft.com"

# Test with non-Microsoft authority host
with patch.dict(
"os.environ",
{
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://custom.authority.com",
EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "centralus",
},
clear=True,
):
client = AadClient("tenant-id", "client-id")
assert client._regional_authority == "https://centralus.custom.authority.com"

# Test with usage of region auto-discovery env var
with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True):
with patch.dict("os.environ", {"REGION_NAME": "eastus"}):
client = AadClient("tenant-id", "client-id")
assert client._regional_authority == "https://eastus.login.microsoft.com"

with patch.dict("os.environ", {EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME: "tryautodetect"}, clear=True):
response = Mock(
status_code=200,
headers={"Content-Type": "text/plain"},
content_type="text/plain",
text=lambda encoding=None: "westus2",
)
transport = mock.Mock(send=mock.Mock(return_value=response))

client = AadClient("tenant-id", "client-id", transport=transport)
assert client._regional_authority == "https://westus2.login.microsoft.com"

with patch.dict("os.environ", {"MSAL_FORCE_REGION": "westus3"}, clear=True):
client = AadClient("tenant-id", "client-id")
assert client._regional_authority == "https://westus3.login.microsoft.com"


@pytest.mark.parametrize("secret", (None, "client secret"))
def test_authorization_code(secret):
tenant_id = "tenant-id"
Expand Down
Loading
Loading