Skip to content

Commit 513989e

Browse files
authored
[Identity] Validate identity config for MICredential (#36950)
ManagedIdentityCredential now validates the inputs for client_id and identity_config to ensure no mutually exclusive values are given. Signed-off-by: Paul Van Eck <[email protected]>
1 parent 3ae4c8d commit 513989e

File tree

7 files changed

+121
-46
lines changed

7 files changed

+121
-46
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
### Other Changes
1212

13+
- Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950))
14+
1315
## 1.18.0b2 (2024-08-09)
1416

1517
### Features Added

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# ------------------------------------
55
import logging
66
import os
7-
from typing import Optional, TYPE_CHECKING, Any
7+
from typing import Optional, TYPE_CHECKING, Any, Mapping
88

99
from azure.core.credentials import AccessToken
1010
from .. import CredentialUnavailableError
@@ -17,6 +17,22 @@
1717
_LOGGER = logging.getLogger(__name__)
1818

1919

20+
def validate_identity_config(client_id: Optional[str], identity_config: Optional[Mapping[str, str]]) -> None:
21+
if identity_config:
22+
if client_id:
23+
if any(key in identity_config for key in ("object_id", "resource_id", "client_id")):
24+
raise ValueError(
25+
"identity_config must not contain 'object_id', 'resource_id', or 'client_id' when 'client_id' is "
26+
"provided as a keyword argument."
27+
)
28+
# Only one of these keys should be present if one is present.
29+
valid_keys = {"object_id", "resource_id", "client_id"}
30+
if len(identity_config.keys() & valid_keys) > 1:
31+
raise ValueError(
32+
f"identity_config must not contain more than one of the following keys: {', '.join(valid_keys)}"
33+
)
34+
35+
2036
class ManagedIdentityCredential:
2137
"""Authenticates with an Azure managed identity in any hosting environment which supports managed identities.
2238
@@ -42,59 +58,66 @@ class ManagedIdentityCredential:
4258
:caption: Create a ManagedIdentityCredential.
4359
"""
4460

45-
def __init__(self, **kwargs: Any) -> None:
46-
self._credential = None # type: Optional[TokenCredential]
61+
def __init__(
62+
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
63+
) -> None:
64+
validate_identity_config(client_id, identity_config)
65+
self._credential: Optional[TokenCredential] = None
4766
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
4867
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
4968
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
5069
if os.environ.get(EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT):
5170
_LOGGER.info("%s will use Service Fabric managed identity", self.__class__.__name__)
5271
from .service_fabric import ServiceFabricCredential
5372

54-
self._credential = ServiceFabricCredential(**kwargs)
73+
self._credential = ServiceFabricCredential(
74+
client_id=client_id, identity_config=identity_config, **kwargs
75+
)
5576
else:
5677
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
5778
from .app_service import AppServiceCredential
5879

59-
self._credential = AppServiceCredential(**kwargs)
80+
self._credential = AppServiceCredential(
81+
client_id=client_id, identity_config=identity_config, **kwargs
82+
)
6083
elif os.environ.get(EnvironmentVariables.IMDS_ENDPOINT):
6184
_LOGGER.info("%s will use Azure Arc managed identity", self.__class__.__name__)
6285
from .azure_arc import AzureArcCredential
6386

64-
self._credential = AzureArcCredential(**kwargs)
87+
self._credential = AzureArcCredential(client_id=client_id, identity_config=identity_config, **kwargs)
6588
elif os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
6689
if os.environ.get(EnvironmentVariables.MSI_SECRET):
6790
_LOGGER.info("%s will use Azure ML managed identity", self.__class__.__name__)
6891
from .azure_ml import AzureMLCredential
6992

70-
self._credential = AzureMLCredential(**kwargs)
93+
self._credential = AzureMLCredential(client_id=client_id, identity_config=identity_config, **kwargs)
7194
else:
7295
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
7396
from .cloud_shell import CloudShellCredential
7497

75-
self._credential = CloudShellCredential(**kwargs)
98+
self._credential = CloudShellCredential(client_id=client_id, identity_config=identity_config, **kwargs)
7699
elif (
77100
all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)
78101
and not exclude_workload_identity
79102
):
80103
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
81104
from .workload_identity import WorkloadIdentityCredential
82105

83-
client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
84-
if not client_id:
106+
workload_client_id = client_id or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
107+
if not workload_client_id:
85108
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')
86109

87110
self._credential = WorkloadIdentityCredential(
88111
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
89-
client_id=client_id,
112+
client_id=workload_client_id,
90113
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
91-
**kwargs
114+
**kwargs,
92115
)
93116
else:
94117
from .imds import ImdsCredential
95118

96119
_LOGGER.info("%s will use IMDS", self.__class__.__name__)
97-
self._credential = ImdsCredential(**kwargs)
120+
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
98121

99122
def __enter__(self) -> "ManagedIdentityCredential":
100123
if self._credential:

sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
from typing import Any, Optional, Dict, cast, Union
5+
from typing import Any, Optional, Dict, cast, Union, Mapping
66
import abc
77
import time
88
import logging
@@ -23,10 +23,12 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v
2323
"""Base class for managed identity client wrapping MSAL ManagedIdentityClient."""
2424

2525
# pylint:disable=missing-client-constructor-parameter-credential
26-
def __init__(self, **kwargs: Any) -> None:
27-
self._settings = kwargs
26+
def __init__(
27+
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
28+
) -> None:
29+
self._settings = {"client_id": client_id, "identity_config": identity_config or {}}
2830
self._client = MsalClient(**kwargs)
29-
managed_identity = self.get_managed_identity(**kwargs)
31+
managed_identity = self.get_managed_identity()
3032
self._msal_client = msal.ManagedIdentityClient(managed_identity, http_client=self._client)
3133

3234
def __enter__(self) -> "MsalManagedIdentityClient":
@@ -56,20 +58,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:
5658
error_message = self.get_unavailable_message(error_desc)
5759
raise CredentialUnavailableError(error_message)
5860

59-
def get_managed_identity(
60-
self, **kwargs: Any
61-
) -> Union[msal.UserAssignedManagedIdentity, msal.SystemAssignedManagedIdentity]:
61+
def get_managed_identity(self) -> Union[msal.UserAssignedManagedIdentity, msal.SystemAssignedManagedIdentity]:
6262
"""
6363
Get the managed identity configuration.
64-
:keyword str client_id: The client ID of the user-assigned managed identity.
65-
:keyword dict identity_config: The identity configuration.
6664
6765
:rtype: msal.UserAssignedManagedIdentity or msal.SystemAssignedManagedIdentity
6866
:return: The managed identity configuration.
6967
"""
70-
if "client_id" in kwargs and kwargs["client_id"]:
71-
return msal.UserAssignedManagedIdentity(client_id=kwargs["client_id"])
72-
identity_config = kwargs.pop("identity_config", None) or {}
68+
69+
if "client_id" in self._settings and self._settings["client_id"]:
70+
return msal.UserAssignedManagedIdentity(client_id=self._settings["client_id"])
71+
identity_config = cast(Dict, self._settings.get("identity_config")) or {}
7372
if "client_id" in identity_config and identity_config["client_id"]:
7473
return msal.UserAssignedManagedIdentity(client_id=identity_config["client_id"])
7574
if "resource_id" in identity_config and identity_config["resource_id"]:
@@ -154,5 +153,5 @@ def __getstate__(self) -> Dict[str, Any]: # pylint:disable=client-method-name-n
154153
def __setstate__(self, state: Dict[str, Any]) -> None: # pylint:disable=client-method-name-no-double-underscore
155154
self.__dict__.update(state)
156155
# Re-create the unpickable entries
157-
managed_identity = self.get_managed_identity(**self._settings)
156+
managed_identity = self.get_managed_identity()
158157
self._msal_client = msal.ManagedIdentityClient(managed_identity, http_client=self._client)

sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
# ------------------------------------
55
import logging
66
import os
7-
from typing import TYPE_CHECKING, Optional, Any
7+
from typing import TYPE_CHECKING, Optional, Any, Mapping
88

99
from azure.core.credentials import AccessToken
1010
from .._internal import AsyncContextManager
1111
from .._internal.decorators import log_get_token_async
1212
from ... import CredentialUnavailableError
1313
from ..._constants import EnvironmentVariables
14+
from ..._credentials.managed_identity import validate_identity_config
1415

1516
if TYPE_CHECKING:
1617
from azure.core.credentials_async import AsyncTokenCredential
@@ -43,8 +44,11 @@ class ManagedIdentityCredential(AsyncContextManager):
4344
:caption: Create a ManagedIdentityCredential.
4445
"""
4546

46-
def __init__(self, **kwargs: Any) -> None:
47-
self._credential = None # type: Optional[AsyncTokenCredential]
47+
def __init__(
48+
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
49+
) -> None:
50+
validate_identity_config(client_id, identity_config)
51+
self._credential: Optional[AsyncTokenCredential] = None
4852
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
4953

5054
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
@@ -53,55 +57,54 @@ def __init__(self, **kwargs: Any) -> None:
5357
_LOGGER.info("%s will use Service Fabric managed identity", self.__class__.__name__)
5458
from .service_fabric import ServiceFabricCredential
5559

56-
self._credential = ServiceFabricCredential(**kwargs)
60+
self._credential = ServiceFabricCredential(
61+
client_id=client_id, identity_config=identity_config, **kwargs
62+
)
5763
else:
5864
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
5965
from .app_service import AppServiceCredential
6066

61-
self._credential = AppServiceCredential(**kwargs)
67+
self._credential = AppServiceCredential(
68+
client_id=client_id, identity_config=identity_config, **kwargs
69+
)
6270
elif os.environ.get(EnvironmentVariables.IMDS_ENDPOINT):
6371
_LOGGER.info("%s will use Azure Arc managed identity", self.__class__.__name__)
6472
from .azure_arc import AzureArcCredential
6573

66-
self._credential = AzureArcCredential(**kwargs)
67-
else:
68-
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
69-
from .cloud_shell import CloudShellCredential
70-
71-
self._credential = CloudShellCredential(**kwargs)
74+
self._credential = AzureArcCredential(client_id=client_id, identity_config=identity_config, **kwargs)
7275
elif os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
7376
if os.environ.get(EnvironmentVariables.MSI_SECRET):
7477
_LOGGER.info("%s will use Azure ML managed identity", self.__class__.__name__)
7578
from .azure_ml import AzureMLCredential
7679

77-
self._credential = AzureMLCredential(**kwargs)
80+
self._credential = AzureMLCredential(client_id=client_id, identity_config=identity_config, **kwargs)
7881
else:
7982
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
8083
from .cloud_shell import CloudShellCredential
8184

82-
self._credential = CloudShellCredential(**kwargs)
85+
self._credential = CloudShellCredential(client_id=client_id, identity_config=identity_config, **kwargs)
8386
elif (
8487
all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)
8588
and not exclude_workload_identity
8689
):
8790
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
8891
from .workload_identity import WorkloadIdentityCredential
8992

90-
client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
91-
if not client_id:
93+
workload_client_id = client_id or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
94+
if not workload_client_id:
9295
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')
9396

9497
self._credential = WorkloadIdentityCredential(
9598
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
96-
client_id=client_id,
99+
client_id=workload_client_id,
97100
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
98101
**kwargs
99102
)
100103
else:
101104
from .imds import ImdsCredential
102105

103106
_LOGGER.info("%s will use IMDS", self.__class__.__name__)
104-
self._credential = ImdsCredential(**kwargs)
107+
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
105108

106109
async def __aenter__(self) -> "ManagedIdentityCredential":
107110
if self._credential:

sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import abc
6-
from typing import Any, cast, Optional, TypeVar
6+
from types import TracebackType
7+
from typing import Any, cast, Optional, TypeVar, Type
78

89
from azure.core.credentials import AccessToken
910
from . import AsyncContextManager
@@ -34,9 +35,14 @@ async def __aenter__(self: T) -> T:
3435
await self._client.__aenter__()
3536
return self
3637

37-
async def __aexit__(self, *args):
38+
async def __aexit__(
39+
self,
40+
exc_type: Optional[Type[BaseException]] = None,
41+
exc_value: Optional[BaseException] = None,
42+
traceback: Optional[TracebackType] = None,
43+
) -> None:
3844
if self._client:
39-
await self._client.__aexit__(*args)
45+
await self._client.__aexit__(exc_type, exc_value, traceback)
4046

4147
async def close(self) -> None:
4248
await self.__aexit__()

sdk/identity/azure-identity/tests/test_managed_identity.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,3 +944,24 @@ def test_token_exchange_tenant_id(tmpdir):
944944
credential = ManagedIdentityCredential(transport=transport)
945945
token = credential.get_token(scope, tenant_id="tenant_id")
946946
assert token.token == access_token
947+
948+
949+
def test_validate_identity_config():
950+
ManagedIdentityCredential()
951+
ManagedIdentityCredential(client_id="foo")
952+
ManagedIdentityCredential(identity_config={"foo": "bar"})
953+
ManagedIdentityCredential(identity_config={"client_id": "foo"})
954+
ManagedIdentityCredential(identity_config={"object_id": "foo"})
955+
ManagedIdentityCredential(identity_config={"resource_id": "foo"})
956+
ManagedIdentityCredential(identity_config={"foo": "bar"}, client_id="foo")
957+
958+
with pytest.raises(ValueError):
959+
ManagedIdentityCredential(identity_config={"client_id": "foo"}, client_id="foo")
960+
with pytest.raises(ValueError):
961+
ManagedIdentityCredential(identity_config={"object_id": "bar"}, client_id="bar")
962+
with pytest.raises(ValueError):
963+
ManagedIdentityCredential(identity_config={"resource_id": "bar"}, client_id="bar")
964+
with pytest.raises(ValueError):
965+
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
966+
with pytest.raises(ValueError):
967+
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})

sdk/identity/azure-identity/tests/test_managed_identity_async.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,3 +1213,24 @@ async def test_token_exchange_tenant_id(tmpdir):
12131213
credential = ManagedIdentityCredential(transport=transport)
12141214
token = await credential.get_token(scope, tenant_id="tenant_id")
12151215
assert token.token == access_token
1216+
1217+
1218+
def test_validate_identity_config():
1219+
ManagedIdentityCredential()
1220+
ManagedIdentityCredential(client_id="foo")
1221+
ManagedIdentityCredential(identity_config={"foo": "bar"})
1222+
ManagedIdentityCredential(identity_config={"client_id": "foo"})
1223+
ManagedIdentityCredential(identity_config={"object_id": "foo"})
1224+
ManagedIdentityCredential(identity_config={"resource_id": "foo"})
1225+
ManagedIdentityCredential(identity_config={"foo": "bar"}, client_id="foo")
1226+
1227+
with pytest.raises(ValueError):
1228+
ManagedIdentityCredential(identity_config={"client_id": "foo"}, client_id="foo")
1229+
with pytest.raises(ValueError):
1230+
ManagedIdentityCredential(identity_config={"object_id": "bar"}, client_id="bar")
1231+
with pytest.raises(ValueError):
1232+
ManagedIdentityCredential(identity_config={"resource_id": "bar"}, client_id="bar")
1233+
with pytest.raises(ValueError):
1234+
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
1235+
with pytest.raises(ValueError):
1236+
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})

0 commit comments

Comments
 (0)