Skip to content

Commit 8547967

Browse files
committed
Simplify 3 managed identity classes into 1
1 parent cdfd8f2 commit 8547967

File tree

3 files changed

+47
-61
lines changed

3 files changed

+47
-61
lines changed

msal/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,5 @@
3333
)
3434
from .oauth2cli.oidc import Prompt
3535
from .token_cache import TokenCache, SerializableTokenCache
36-
from .imds import (
37-
SystemAssignedManagedIdentity,
38-
UserAssignedManagedIdentity,
39-
ManagedIdentityClient,
40-
)
36+
from .imds import ManagedIdentity, ManagedIdentityClient
4137

msal/imds.py

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,83 +21,75 @@
2121
logger = logging.getLogger(__name__)
2222

2323
class ManagedIdentity(UserDict):
24+
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
25+
to acquire token for the specified managed identity.
26+
"""
2427
# The key names used in config dict
25-
ID_TYPE = "ManagedIdentityIdType"
28+
ID_TYPE = "ManagedIdentityIdType" # Contains keyword ManagedIdentity so its json equivalent will be more readable
2629
ID = "Id"
27-
def __init__(self, identifier=None, id_type=None):
28-
super(ManagedIdentity, self).__init__({
29-
self.ID_TYPE: id_type,
30-
self.ID: identifier,
31-
})
32-
3330

34-
class UserAssignedManagedIdentity(ManagedIdentity):
35-
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
36-
to acquire token for user-assigned managed identity.
37-
"""
31+
# Valid values for key ID_TYPE
3832
CLIENT_ID = "ClientId"
3933
RESOURCE_ID = "ResourceId"
4034
OBJECT_ID = "ObjectId"
35+
SYSTEM_ASSIGNED = "SystemAssigned"
36+
4137
_types_mapping = { # Maps type name in configuration to type name on wire
4238
CLIENT_ID: "client_id",
4339
RESOURCE_ID: "mi_res_id",
4440
OBJECT_ID: "object_id",
4541
}
46-
def __init__(self, identifier, id_type):
47-
"""Do not use this contructor. Use the following factory methods instead."""
48-
if id_type not in self._types_mapping:
49-
raise ValueError("id_type only accepts one of: {}".format(
50-
list(self._types_mapping)))
51-
super(UserAssignedManagedIdentity, self).__init__(
52-
identifier=identifier,
53-
id_type=id_type,
54-
)
5542

5643
@classmethod
57-
def from_client_id(cls, identifier):
58-
"""Construct a UserAssignedManagedIdentity instance from a client id.
44+
def system_assigned(cls):
45+
"""Construct a system-assigned managed identity.
5946
60-
The outcome will be equivalent to::
47+
The outcome is equivalent to::
48+
49+
{"ManagedIdentityIdType": "SystemAssigned", "Id": None}
50+
"""
51+
return ManagedIdentity(id_type=cls.SYSTEM_ASSIGNED)
52+
53+
@classmethod
54+
def is_system_assigned(cls, unknown):
55+
return isinstance(unknown, dict) and unknown.get(cls.ID_TYPE) == cls.SYSTEM_ASSIGNED
56+
57+
@classmethod
58+
def user_assigned_client_id(cls, identifier):
59+
"""Construct a ``ManagedIdentity`` instance from a user-assigned client id.
60+
61+
The outcome is equivalent to::
6162
6263
{"ManagedIdentityIdType": "ClientId", "Id": "foo"}
6364
"""
64-
return UserAssignedManagedIdentity(identifier, cls.CLIENT_ID)
65+
return ManagedIdentity(identifier=identifier, id_type=cls.CLIENT_ID)
6566

6667
@classmethod
67-
def from_resource_id(cls, identifier):
68-
"""Construct a UserAssignedManagedIdentity instance from a resource id.
68+
def user_assigned_resource_id(cls, identifier):
69+
"""Construct a ``ManagedIdentity`` instance from a user-assigned resource id.
6970
70-
The outcome will be equivalent to::
71+
The outcome is equivalent to::
7172
7273
{"ManagedIdentityIdType": "ResourceId", "Id": "foo"}
7374
"""
74-
return UserAssignedManagedIdentity(identifier, cls.RESOURCE_ID)
75+
return ManagedIdentity(identifier=identifier, id_type=cls.RESOURCE_ID)
7576

7677
@classmethod
77-
def from_object_id(cls, identifier):
78-
"""Construct a UserAssignedManagedIdentity instance from an object id.
78+
def user_assigned_object_id(cls, identifier):
79+
"""Construct a ManagedIdentity instance from a user-assigned object id.
7980
8081
The outcome will be equivalent to::
8182
8283
{"ManagedIdentityIdType": "ObjectId", "Id": "foo"}
8384
"""
84-
return UserAssignedManagedIdentity(identifier, cls.OBJECT_ID)
85-
86-
87-
class SystemAssignedManagedIdentity(ManagedIdentity):
88-
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
89-
to acquire token for system-assigned managed identity.
85+
return ManagedIdentity(identifier=identifier, id_type=cls.OBJECT_ID)
9086

91-
By design, an instance of this class is equivalent to::
92-
93-
{"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": None}
94-
"""
95-
def __init__(self):
96-
super(SystemAssignedManagedIdentity, self).__init__(
97-
id_type="SystemAssignedManagedIdentity", # As of this writing,
98-
# It can be any value other than
99-
# UserAssignedManagedIdentity._types_mapping's key names
100-
)
87+
def __init__(self, identifier=None, id_type=None):
88+
# Undocumented. Use other class methods instead.
89+
super(ManagedIdentity, self).__init__({
90+
self.ID_TYPE: id_type,
91+
self.ID: identifier,
92+
})
10193

10294

10395
def _scope_to_resource(scope): # This is an experimental reasonable-effort approach
@@ -136,7 +128,7 @@ def _obtain_token(http_client, managed_identity, resource):
136128

137129

138130
def _adjust_param(params, managed_identity):
139-
id_name = UserAssignedManagedIdentity._types_mapping.get(
131+
id_name = ManagedIdentity._types_mapping.get(
140132
managed_identity.get(ManagedIdentity.ID_TYPE))
141133
if id_name:
142134
params[id_name] = managed_identity[ManagedIdentity.ID]

tests/test_mi.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,23 @@
99
import requests
1010

1111
from tests.http_client import MinimalResponse
12-
from msal import (
13-
TokenCache,
14-
SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient)
12+
from msal import TokenCache, ManagedIdentity, ManagedIdentityClient
1513

1614

1715
class ManagedIdentityTestCase(unittest.TestCase):
1816
def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self):
1917
self.assertEqual(
20-
UserAssignedManagedIdentity.from_client_id("foo"),
18+
ManagedIdentity.user_assigned_client_id("foo"),
2119
{"ManagedIdentityIdType": "ClientId", "Id": "foo"})
2220
self.assertEqual(
23-
UserAssignedManagedIdentity.from_resource_id("foo"),
21+
ManagedIdentity.user_assigned_resource_id("foo"),
2422
{"ManagedIdentityIdType": "ResourceId", "Id": "foo"})
2523
self.assertEqual(
26-
UserAssignedManagedIdentity.from_object_id("foo"),
24+
ManagedIdentity.user_assigned_object_id("foo"),
2725
{"ManagedIdentityIdType": "ObjectId", "Id": "foo"})
2826
self.assertEqual(
29-
SystemAssignedManagedIdentity(),
30-
{"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": None})
27+
ManagedIdentity.system_assigned(),
28+
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})
3129

3230

3331
class ClientTestCase(unittest.TestCase):
@@ -38,7 +36,7 @@ def setUp(self):
3836
requests.Session(),
3937
{ # Here we test it with the raw dict form, to test that
4038
# the client has no hard dependency on ManagedIdentity object
41-
"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": None,
39+
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
4240
},
4341
token_cache=TokenCache())
4442

0 commit comments

Comments
 (0)