Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

### Bugs Fixed

- Fixed the `AZURE_REGIONAL_AUTHORITY_NAME` environment variable not being respected in certain credentials. ([#44347](https://github.com/Azure/azure-sdk-for-python/pull/44347))

### Other Changes

## 1.26.0b1 (2025-11-07)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Licensed under the MIT License.
# ------------------------------------
import time
import logging
import os
from typing import Iterable, Union, Optional, Any

from azure.core.credentials import AccessTokenInfo
Expand All @@ -11,9 +13,18 @@
from .aad_client_base import AadClientBase
from .aadclient_certificate import AadClientCertificate
from .pipeline import build_pipeline
from .._enums import RegionalAuthority


class AadClient(AadClientBase):
_LOGGER = logging.getLogger(__name__)


class AadClient(AadClientBase): # pylint:disable=client-accepts-api-version-keyword

# pylint:disable=missing-client-constructor-parameter-credential
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def __enter__(self) -> "AadClient":
self._pipeline.__enter__()
return self
Expand All @@ -27,6 +38,7 @@ def close(self) -> None:
def obtain_token_by_authorization_code(
self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs: Any
) -> AccessTokenInfo:
self._initialize_regional_authority()
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
Expand All @@ -35,20 +47,24 @@ def obtain_token_by_authorization_code(
def obtain_token_by_client_certificate(
self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs: Any
) -> AccessTokenInfo:
self._initialize_regional_authority()
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
return self._run_pipeline(request, **kwargs)

def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> AccessTokenInfo:
self._initialize_regional_authority()
request = self._get_client_secret_request(scopes, secret, **kwargs)
return self._run_pipeline(request, **kwargs)

def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs: Any) -> AccessTokenInfo:
self._initialize_regional_authority()
request = self._get_jwt_assertion_request(scopes, assertion, **kwargs)
return self._run_pipeline(request, **kwargs)

def obtain_token_by_refresh_token(
self, scopes: Iterable[str], refresh_token: str, **kwargs: Any
) -> AccessTokenInfo:
self._initialize_regional_authority()
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
return self._run_pipeline(request, **kwargs)

Expand All @@ -62,6 +78,37 @@ def obtain_token_on_behalf_of(
# no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL
raise NotImplementedError()

def _initialize_regional_authority(self) -> None:
# This is based on MSAL's regional authority logic.
if self._regional_authority is not False:
return

regional_authority = self._get_regional_authority_from_env()
if not regional_authority:
self._regional_authority = None
return

if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]:
regional_authority = self._discover_region()
if not regional_authority:
_LOGGER.info("Failed to auto-discover region. Using the non-regional authority.")
self._regional_authority = None
return

self._regional_authority = self._build_regional_authority_url(regional_authority)

def _discover_region(self) -> Optional[str]:
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
if region:
return region
try:
request = self._get_region_discovery_request()
response = self._pipeline.run(request)
return self._process_region_discovery_response(response)
except Exception as ex: # pylint: disable=broad-except
_LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex)
return None

def _build_pipeline(self, **kwargs: Any) -> Pipeline:
return build_pipeline(**kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import abc
import base64
import json
import logging
import os
import time
from uuid import uuid4
from urllib.parse import urlparse
from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast

from msal import TokenCache
Expand All @@ -19,6 +22,7 @@
from .utils import get_default_authority, normalize_authority, resolve_tenant
from .aadclient_certificate import AadClientCertificate
from .._persistent_cache import _load_persistent_cache
from .._constants import EnvironmentVariables


if TYPE_CHECKING:
Expand All @@ -30,10 +34,12 @@
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

_LOGGER = logging.getLogger(__name__)

JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"


class AadClientBase(abc.ABC):
class AadClientBase(abc.ABC): # pylint: disable=too-many-instance-attributes
_POST = ["POST"]

def __init__(
Expand All @@ -45,10 +51,13 @@ def __init__(
cae_cache: Optional[TokenCache] = None,
*,
additionally_allowed_tenants: Optional[List[str]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self._authority = normalize_authority(authority) if authority else get_default_authority()

# False indicates uninitialized. Actual value is str or None.
self._regional_authority: Optional[Union[str, bool]] = False

self._tenant_id = tenant_id
self._client_id = client_id
self._additionally_allowed_tenants = additionally_allowed_tenants or []
Expand Down Expand Up @@ -293,7 +302,7 @@ def _get_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"assertion": user_assertion,
Expand Down Expand Up @@ -348,7 +357,7 @@ def _get_refresh_token_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"grant_type": "refresh_token",
Expand All @@ -375,11 +384,43 @@ def _get_refresh_token_on_behalf_of_request(
request = self._post(data, **kwargs)
return request

def _get_region_discovery_request(self) -> HttpRequest:
url = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=2021-01-01"
request = HttpRequest("GET", url, headers={"Metadata": "true"})
return request

def _process_region_discovery_response(self, response: PipelineResponse) -> Optional[str]:
if response.http_response.status_code == 200:
region = response.http_response.text().strip()
if region:
return region
_LOGGER.warning("IMDS returned empty region")
return None

def _build_regional_authority_url(self, regional_authority: str) -> Optional[str]:
central_host = urlparse(self._authority).hostname
if not central_host:
return None

# This mirrors the regional authority logic in MSAL.
if central_host in ("login.microsoftonline.com", "login.microsoft.com", "login.windows.net", "sts.windows.net"):
regional_host = f"{regional_authority}.login.microsoft.com"
else:
regional_host = f"{regional_authority}.{central_host}"
return f"https://{regional_host}"

def _get_regional_authority_from_env(self) -> Optional[str]:
regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) or os.environ.get(
"MSAL_FORCE_REGION"
) # For parity with creds that rely on MSAL, we check this var too.
return regional_authority.lower() if regional_authority else None

def _get_token_url(self, **kwargs: Any) -> str:
tenant = resolve_tenant(
self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
)
return "/".join((self._authority, tenant, "oauth2/v2.0/token"))
authority = cast(str, self._regional_authority) if self._regional_authority else self._authority
return "/".join((authority, tenant, "oauth2/v2.0/token"))

def _post(self, data: Dict, **kwargs: Any) -> HttpRequest:
url = self._get_token_url(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
import time
from typing import Iterable, Optional, Union, Dict, Any

Expand All @@ -12,9 +14,12 @@
from ..._internal import AadClientCertificate
from ..._internal import AadClientBase
from ..._internal.pipeline import build_async_pipeline
from ..._enums import RegionalAuthority

Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]

_LOGGER = logging.getLogger(__name__)


# pylint:disable=invalid-overridden-method
class AadClient(AadClientBase):
Expand All @@ -33,6 +38,7 @@ async def close(self) -> None:
async def obtain_token_by_authorization_code(
self, scopes: Iterable[str], code: str, redirect_uri: str, client_secret: Optional[str] = None, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
Expand All @@ -41,20 +47,24 @@ async def obtain_token_by_authorization_code(
async def obtain_token_by_client_certificate(
self, scopes: Iterable[str], certificate: AadClientCertificate, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_client_secret(self, scopes: Iterable[str], secret: str, **kwargs) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_client_secret_request(scopes, secret, **kwargs)
return await self._run_pipeline(request, **kwargs)

async def obtain_token_by_jwt_assertion(self, scopes: Iterable[str], assertion: str, **kwargs) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_jwt_assertion_request(scopes, assertion, **kwargs)
return await self._run_pipeline(request, stream=False, **kwargs)

async def obtain_token_by_refresh_token(
self, scopes: Iterable[str], refresh_token: str, **kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
return await self._run_pipeline(request, **kwargs)

Expand All @@ -65,6 +75,7 @@ async def obtain_token_by_refresh_token_on_behalf_of( # pylint: disable=name-to
refresh_token: str,
**kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_refresh_token_on_behalf_of_request(
scopes, client_credential=client_credential, refresh_token=refresh_token, **kwargs
)
Expand All @@ -77,6 +88,7 @@ async def obtain_token_on_behalf_of(
user_assertion: str,
**kwargs
) -> AccessTokenInfo:
await self._initialize_regional_authority()
request = self._get_on_behalf_of_request(
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
)
Expand All @@ -85,6 +97,38 @@ async def obtain_token_on_behalf_of(
def _build_pipeline(self, **kwargs) -> AsyncPipeline:
return build_async_pipeline(**kwargs)

async def _initialize_regional_authority(self) -> None:
# This is based on MSAL's regional authority logic.
if self._regional_authority is not False:
return

regional_authority = self._get_regional_authority_from_env()
if not regional_authority:
self._regional_authority = None
return

if regional_authority in [RegionalAuthority.AUTO_DISCOVER_REGION, "true"]:
# Attempt to discover the region from IMDS
regional_authority = await self._discover_region()
if not regional_authority:
_LOGGER.info("Failed to auto-discover region. Using the non-regional authority.")
self._regional_authority = None
return

self._regional_authority = self._build_regional_authority_url(regional_authority)

async def _discover_region(self) -> Optional[str]:
region = os.environ.get("REGION_NAME", "").replace(" ", "").lower()
if region:
return region
try:
request = self._get_region_discovery_request()
response = await self._pipeline.run(request)
return self._process_region_discovery_response(response)
except Exception as ex: # pylint: disable=broad-except
_LOGGER.debug("Failed to discover Azure region from IMDS: %s", ex)
return None

async def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessTokenInfo:
# remove tenant_id and claims kwarg that could have been passed from credential's get_token method
# tenant_id is already part of `request` at this point
Expand Down
Loading