Skip to content

Commit 548ea62

Browse files
xiangyan99pvaneck
andauthored
Add cloud shell validation (Azure#37278)
* Add cloud shell validation * update * update * update * update * update * Update sdk/identity/azure-identity/CHANGELOG.md Co-authored-by: Paul Van Eck <[email protected]> * rename tests --------- Co-authored-by: Paul Van Eck <[email protected]>
1 parent 04e0fda commit 548ea62

File tree

9 files changed

+84
-13
lines changed

9 files changed

+84
-13
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
### Other Changes
1212

1313
- Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950))
14+
- Additional validation was added for `ManagedIdentityCredential` in Azure Cloud Shell environments. ([#36438](https://github.com/Azure/azure-sdk-for-python/issues/36438))
1415

1516
## 1.18.0b2 (2024-08-09)
1617

sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,33 @@
44
# ------------------------------------
55
import functools
66
import os
7-
from typing import Any, Optional, Dict
7+
from typing import Any, Optional, Dict, Mapping
88

99
from azure.core.pipeline.transport import HttpRequest
1010

1111
from .._constants import EnvironmentVariables
12+
from .._internal import within_dac
1213
from .._internal.managed_identity_client import ManagedIdentityClient
1314
from .._internal.managed_identity_base import ManagedIdentityBase
1415

1516

17+
def validate_client_id_and_config(client_id: Optional[str], identity_config: Optional[Mapping[str, str]]) -> None:
18+
if within_dac.get():
19+
return
20+
if client_id:
21+
raise ValueError("client_id should not be set for cloud shell managed identity.")
22+
if identity_config:
23+
valid_keys = {"object_id", "resource_id", "client_id"}
24+
if len(identity_config.keys() & valid_keys) > 0:
25+
raise ValueError(f"identity_config must not contain the following keys: {', '.join(valid_keys)}")
26+
27+
1628
class CloudShellCredential(ManagedIdentityBase):
1729
def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
30+
client_id = kwargs.get("client_id")
31+
identity_config = kwargs.get("identity_config")
32+
validate_client_id_and_config(client_id, identity_config)
33+
1834
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
1935
if url:
2036
return ManagedIdentityClient(

sdk/identity/azure-identity/azure/identity/_credentials/default.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
145145
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)
146146

147147
credentials: List["TokenCredential"] = []
148+
within_dac.set(True)
148149
if not exclude_environment_credential:
149150
credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs))
150151
if not exclude_workload_identity_credential:
@@ -192,7 +193,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
192193
)
193194
else:
194195
credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, **kwargs))
195-
196+
within_dac.set(False)
196197
super(DefaultAzureCredential, self).__init__(*credentials)
197198

198199
def get_token(

sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
from .._internal.managed_identity_base import AsyncManagedIdentityBase
1010
from .._internal.managed_identity_client import AsyncManagedIdentityClient
1111
from ..._constants import EnvironmentVariables
12-
from ..._credentials.cloud_shell import _get_request
12+
from ..._credentials.cloud_shell import _get_request, validate_client_id_and_config
1313

1414

1515
class CloudShellCredential(AsyncManagedIdentityBase):
1616
def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]:
17+
client_id = kwargs.get("client_id")
18+
identity_config = kwargs.get("identity_config")
19+
validate_client_id_and_config(client_id, identity_config)
20+
1721
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
1822
if url:
1923
return AsyncManagedIdentityClient(

sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class DefaultAzureCredential(ChainedTokenCredential):
8989
:caption: Create a DefaultAzureCredential.
9090
"""
9191

92-
def __init__(self, **kwargs: Any) -> None:
92+
def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statements, too-many-locals
9393
if "tenant_id" in kwargs:
9494
raise TypeError("'tenant_id' is not supported in DefaultAzureCredential.")
9595

@@ -135,6 +135,7 @@ def __init__(self, **kwargs: Any) -> None:
135135
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)
136136

137137
credentials = [] # type: List[AsyncTokenCredential]
138+
within_dac.set(True)
138139
if not exclude_environment_credential:
139140
credentials.append(EnvironmentCredential(authority=authority, _within_dac=True, **kwargs))
140141
if not exclude_workload_identity_credential:
@@ -173,7 +174,7 @@ def __init__(self, **kwargs: Any) -> None:
173174
credentials.append(AzurePowerShellCredential(process_timeout=process_timeout))
174175
if not exclude_developer_cli_credential:
175176
credentials.append(AzureDeveloperCliCredential(process_timeout=process_timeout))
176-
177+
within_dac.set(False)
177178
super().__init__(*credentials)
178179

179180
async def get_token(

sdk/identity/azure-identity/tests/test_default.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,13 @@ def test_unexpected_kwarg():
412412
def test_error_tenant_id():
413413
with pytest.raises(TypeError):
414414
DefaultAzureCredential(tenant_id="foo")
415+
416+
417+
def test_validate_cloud_shell_credential_in_dac():
418+
MANAGED_IDENTITY_ENVIRON = "azure.identity._credentials.managed_identity.os.environ"
419+
with patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True):
420+
DefaultAzureCredential()
421+
DefaultAzureCredential(managed_identity_client_id="foo")
422+
DefaultAzureCredential(identity_config={"client_id": "foo"})
423+
DefaultAzureCredential(identity_config={"object_id": "foo"})
424+
DefaultAzureCredential(identity_config={"resource_id": "foo"})

sdk/identity/azure-identity/tests/test_default_async.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,13 @@ def test_unexpected_kwarg():
326326
def test_error_tenant_id():
327327
with pytest.raises(TypeError):
328328
DefaultAzureCredential(tenant_id="foo")
329+
330+
331+
def test_validate_cloud_shell_credential_in_dac():
332+
MANAGED_IDENTITY_ENVIRON = "azure.identity.aio._credentials.managed_identity.os.environ"
333+
with patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True):
334+
DefaultAzureCredential()
335+
DefaultAzureCredential(managed_identity_client_id="foo")
336+
DefaultAzureCredential(identity_config={"client_id": "foo"})
337+
DefaultAzureCredential(identity_config={"object_id": "foo"})
338+
DefaultAzureCredential(identity_config={"resource_id": "foo"})

sdk/identity/azure-identity/tests/test_managed_identity.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,11 @@ def test_azure_ml_tenant_id():
309309
assert token.expires_on == expected_token.expires_on
310310

311311

312-
def test_cloud_shell_user_assigned_identity():
312+
def test_cloud_shell_identity_config():
313313
"""Cloud Shell environment: only MSI_ENDPOINT set"""
314314

315315
expected_token = "****"
316316
expires_on = 42
317-
client_id = "some-guid"
318317
endpoint = "http://localhost:42/token"
319318
scope = "scope"
320319
param_name, param_value = "foo", "bar"
@@ -325,7 +324,7 @@ def test_cloud_shell_user_assigned_identity():
325324
base_url=endpoint,
326325
method="POST",
327326
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
328-
required_data={"client_id": client_id, "resource": scope},
327+
required_data={"resource": scope},
329328
),
330329
Request(
331330
base_url=endpoint,
@@ -350,7 +349,7 @@ def test_cloud_shell_user_assigned_identity():
350349
)
351350

352351
with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
353-
token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope)
352+
token = ManagedIdentityCredential(transport=transport).get_token(scope)
354353
assert token.token == expected_token
355354
assert token.expires_on == expires_on
356355

@@ -965,3 +964,18 @@ def test_validate_identity_config():
965964
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
966965
with pytest.raises(ValueError):
967966
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})
967+
968+
969+
def test_validate_cloud_shell_credential():
970+
with mock.patch.dict(
971+
MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True
972+
):
973+
ManagedIdentityCredential()
974+
with pytest.raises(ValueError):
975+
ManagedIdentityCredential(client_id="foo")
976+
with pytest.raises(ValueError):
977+
ManagedIdentityCredential(identity_config={"client_id": "foo"})
978+
with pytest.raises(ValueError):
979+
ManagedIdentityCredential(identity_config={"object_id": "foo"})
980+
with pytest.raises(ValueError):
981+
ManagedIdentityCredential(identity_config={"resource_id": "foo"})

sdk/identity/azure-identity/tests/test_managed_identity_async.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,11 @@ async def test_azure_ml_tenant_id():
313313

314314

315315
@pytest.mark.asyncio
316-
async def test_cloud_shell_user_assigned_identity():
316+
async def test_cloud_shell_identity_config():
317317
"""Cloud Shell environment: only MSI_ENDPOINT set"""
318318

319319
expected_token = "****"
320320
expires_on = 42
321-
client_id = "some-guid"
322321
endpoint = "http://localhost:42/token"
323322
scope = "scope"
324323
param_name, param_value = "foo", "bar"
@@ -329,7 +328,7 @@ async def test_cloud_shell_user_assigned_identity():
329328
base_url=endpoint,
330329
method="POST",
331330
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
332-
required_data={"client_id": client_id, "resource": scope},
331+
required_data={"resource": scope},
333332
),
334333
Request(
335334
base_url=endpoint,
@@ -354,7 +353,7 @@ async def test_cloud_shell_user_assigned_identity():
354353
)
355354

356355
with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
357-
credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
356+
credential = ManagedIdentityCredential(transport=transport)
358357
token = await credential.get_token(scope)
359358
assert token.token == expected_token
360359
assert token.expires_on == expires_on
@@ -1234,3 +1233,18 @@ def test_validate_identity_config():
12341233
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
12351234
with pytest.raises(ValueError):
12361235
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})
1236+
1237+
1238+
def test_validate_cloud_shell_credential():
1239+
with mock.patch.dict(
1240+
MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True
1241+
):
1242+
ManagedIdentityCredential()
1243+
with pytest.raises(ValueError):
1244+
ManagedIdentityCredential(client_id="foo")
1245+
with pytest.raises(ValueError):
1246+
ManagedIdentityCredential(identity_config={"client_id": "foo"})
1247+
with pytest.raises(ValueError):
1248+
ManagedIdentityCredential(identity_config={"object_id": "foo"})
1249+
with pytest.raises(ValueError):
1250+
ManagedIdentityCredential(identity_config={"resource_id": "foo"})

0 commit comments

Comments
 (0)