Skip to content

Commit 07395dd

Browse files
authored
[Identity] Skip IMDS probe when MI selected in DAC via env (#43080)
When the environment variable AZURE_TOKEN_CREDENTIALS is explicitly set to ManagedIdentityCredential, DefaultAzureCredential should not do IMDS probing. Signed-off-by: Paul Van Eck <[email protected]> * Update kwarg name used Signed-off-by: Paul Van Eck <[email protected]> * fix logic Signed-off-by: Paul Van Eck <[email protected]> * Change to check envvar specifically Signed-off-by: Paul Van Eck <[email protected]> * Add comments Signed-off-by: Paul Van Eck <[email protected]> --------- Signed-off-by: Paul Van Eck <[email protected]>
1 parent b0ec83c commit 07395dd

15 files changed

+273
-22
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
### Other Changes
1212

13+
- When `AZURE_TOKEN_CREDENTIALS` is set to `ManagedIdentityCredential`, `DefaultAzureCredential` now skips the IMDS endpoint probe request and directly attempts token acquisition with full retry logic, matching the behavior of using `ManagedIdentityCredential` standalone. ([#43080](https://github.com/Azure/azure-sdk-for-python/pull/43080))
1314
- Improved error messages from `ManagedIdentityCredential` to include the full error response from managed identity endpoints for better troubleshooting. ([#43231](https://github.com/Azure/azure-sdk-for-python/pull/43231))
1415

1516
## 1.25.0 (2025-09-11)

sdk/identity/azure-identity/azure/identity/_credentials/default.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
172172

173173
process_timeout = kwargs.pop("process_timeout", 10)
174174
require_envvar = kwargs.pop("require_envvar", False)
175-
if require_envvar and not os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS):
175+
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
176+
if require_envvar and not token_credentials_env:
176177
raise ValueError(
177178
"AZURE_TOKEN_CREDENTIALS environment variable is required but is not set or is empty. "
178179
"Set it to 'dev', 'prod', or a specific credential name."
@@ -274,6 +275,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
274275
ManagedIdentityCredential(
275276
client_id=managed_identity_client_id,
276277
_exclude_workload_identity_credential=exclude_workload_identity_credential,
278+
_enable_imds_probe=token_credentials_env != "managedidentitycredential",
277279
**kwargs,
278280
)
279281
)

sdk/identity/azure-identity/azure/identity/_credentials/imds.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def _check_forbidden_response(ex: HttpResponseError) -> None:
8282

8383
class ImdsCredential(MsalManagedIdentityClient):
8484
def __init__(self, **kwargs: Any) -> None:
85+
# If set to True/False, _enable_imds_probe forces whether or not the credential
86+
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
87+
# the credential probes only if it's part of a ChainedTokenCredential chain.
88+
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
8589
super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs))
8690
self._config = kwargs
8791

@@ -102,9 +106,9 @@ def close(self) -> None:
102106

103107
def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
104108

105-
if within_credential_chain.get() and not self._endpoint_available:
106-
# If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint
107-
# is available to avoid hanging for a long time if the endpoint isn't available.
109+
do_probe = self._enable_imds_probe if self._enable_imds_probe is not None else within_credential_chain.get()
110+
if do_probe and not self._endpoint_available:
111+
# Probe to see if the IMDS endpoint is available to avoid hanging for a long time if it's not.
108112
try:
109113
client = ManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **self._config))
110114
client.request_token(*scopes, connection_timeout=1, retry_total=0)

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
user_identity_info = validate_identity_config(client_id, identity_config)
7777
self._credential: Optional[SupportsTokenInfo] = None
7878
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
79+
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
7980
managed_identity_type = None
8081

8182
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
@@ -136,7 +137,12 @@ def __init__(
136137
managed_identity_type = "IMDS"
137138
from .imds import ImdsCredential
138139

139-
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
140+
self._credential = ImdsCredential(
141+
client_id=client_id,
142+
identity_config=identity_config,
143+
_enable_imds_probe=self._enable_imds_probe,
144+
**kwargs,
145+
)
140146

141147
if managed_identity_type:
142148
log_msg = f"{self.__class__.__name__} will use {managed_identity_type}"

sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
144144

145145
process_timeout = kwargs.pop("process_timeout", 10)
146146
require_envvar = kwargs.pop("require_envvar", False)
147-
if require_envvar and not os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS):
147+
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
148+
if require_envvar and not token_credentials_env:
148149
raise ValueError(
149150
"AZURE_TOKEN_CREDENTIALS environment variable is required but is not set or is empty. "
150151
"Set it to 'dev', 'prod', or a specific credential name."
@@ -235,6 +236,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
235236
ManagedIdentityCredential(
236237
client_id=managed_identity_client_id,
237238
_exclude_workload_identity_credential=exclude_workload_identity_credential,
239+
_enable_imds_probe=token_credentials_env != "managedidentitycredential",
238240
**kwargs,
239241
)
240242
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class ImdsCredential(AsyncContextManager, GetTokenMixin):
4545
def __init__(self, **kwargs: Any) -> None:
4646
super().__init__()
4747

48+
# If set to True/False, _enable_imds_probe forces whether or not the credential
49+
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
50+
# the credential probes only if it's part of a ChainedTokenCredential chain.
51+
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
4852
kwargs["retry_policy_class"] = AsyncImdsRetryPolicy
4953
self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs))
5054
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
@@ -65,9 +69,9 @@ async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional
6569

6670
async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
6771

68-
if within_credential_chain.get() and not self._endpoint_available:
69-
# If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint
70-
# is available to avoid hanging for a long time if the endpoint isn't available.
72+
do_probe = self._enable_imds_probe if self._enable_imds_probe is not None else within_credential_chain.get()
73+
if do_probe and not self._endpoint_available:
74+
# Probe to see if the IMDS endpoint is available to avoid hanging for a long time if it's not.
7175
try:
7276
await self._client.request_token(*scopes, connection_timeout=1, retry_total=0)
7377
self._endpoint_available = True

sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
user_identity_info = validate_identity_config(client_id, identity_config)
5050
self._credential: Optional[AsyncSupportsTokenInfo] = None
5151
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
52+
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
5253
managed_identity_type = None
5354
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
5455
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
@@ -108,7 +109,12 @@ def __init__(
108109
managed_identity_type = "IMDS"
109110
from .imds import ImdsCredential
110111

111-
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
112+
self._credential = ImdsCredential(
113+
client_id=client_id,
114+
identity_config=identity_config,
115+
_enable_imds_probe=self._enable_imds_probe,
116+
**kwargs,
117+
)
112118

113119
if managed_identity_type:
114120
log_msg = f"{self.__class__.__name__} will use {managed_identity_type}"

sdk/identity/azure-identity/tests/test_default.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,11 @@ def test_default_credential_shared_cache_use(mock_credential):
313313
def test_managed_identity_client_id():
314314
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""
315315

316-
expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}
316+
expected_args = {
317+
"client_id": "the-client",
318+
"_exclude_workload_identity_credential": False,
319+
"_enable_imds_probe": True,
320+
}
317321

318322
with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
319323
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])

sdk/identity/azure-identity/tests/test_default_async.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,11 @@ async def test_default_credential_shared_cache_use():
262262
def test_managed_identity_client_id():
263263
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""
264264

265-
expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}
265+
expected_args = {
266+
"client_id": "the-client",
267+
"_exclude_workload_identity_credential": False,
268+
"_enable_imds_probe": True,
269+
}
266270

267271
with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
268272
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])

sdk/identity/azure-identity/tests/test_imds_credential.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
IMDS_AUTHORITY,
1515
PIPELINE_SETTINGS,
1616
)
17-
from azure.identity._internal.utils import within_credential_chain
1817
from azure.core.pipeline import PipelineResponse
1918
from azure.core.pipeline.policies import RetryPolicy
2019
from azure.core.pipeline.transport import HttpRequest, HttpResponse
@@ -109,7 +108,7 @@ def test_user_assigned_tenant_id(self, recorded_test, get_token_method):
109108
assert isinstance(token.expires_on, int)
110109

111110
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
112-
def test_managed_identity_aci_probe(self, get_token_method):
111+
def test_enable_imds_probe(self, get_token_method):
113112
access_token = "****"
114113
expires_on = 42
115114
expected_token = access_token
@@ -140,11 +139,9 @@ def test_managed_identity_aci_probe(self, get_token_method):
140139
),
141140
],
142141
)
143-
within_credential_chain.set(True)
144-
credential = ImdsCredential(transport=transport)
142+
credential = ImdsCredential(transport=transport, _enable_imds_probe=True)
145143
token = getattr(credential, get_token_method)(scope)
146144
assert token.token == expected_token
147-
within_credential_chain.set(False)
148145

149146
def test_imds_credential_uses_custom_retry_policy(self):
150147
credential = ImdsCredential()

0 commit comments

Comments
 (0)