Skip to content

Commit 4680e5c

Browse files
committed
mi-msal
1 parent 89d2232 commit 4680e5c

File tree

2 files changed

+56
-18
lines changed

2 files changed

+56
-18
lines changed

src/azure-cli-core/azure/cli/core/_profile.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,10 @@ def login(self,
227227

228228
def login_with_managed_identity(self, identity_id=None, client_id=None, object_id=None, resource_id=None,
229229
allow_no_subscriptions=None):
230-
if _on_azure_arc():
231-
return self.login_with_managed_identity_azure_arc(
232-
identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions)
230+
if _use_msal_managed_identity(self.cli_ctx):
231+
return self.login_with_managed_identity_msal(
232+
client_id=client_id, object_id=object_id, resource_id=resource_id,
233+
allow_no_subscriptions=allow_no_subscriptions)
233234

234235
import jwt
235236
from azure.mgmt.core.tools import is_valid_resource_id
@@ -310,13 +311,31 @@ def login_with_managed_identity(self, identity_id=None, client_id=None, object_i
310311
self._set_subscriptions(consolidated)
311312
return deepcopy(consolidated)
312313

313-
def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subscriptions=None):
314+
def login_with_managed_identity_msal(self, client_id=None, object_id=None, resource_id=None,
315+
allow_no_subscriptions=None):
314316
import jwt
315-
identity_type = MsiAccountTypes.system_assigned
316-
from .auth.msal_credentials import ManagedIdentityCredential
317317
from .auth.constants import ACCESS_TOKEN
318318

319-
cred = ManagedIdentityCredential()
319+
id_arg_count = len([arg for arg in (client_id, object_id, resource_id) if arg])
320+
if id_arg_count > 1:
321+
raise CLIError('Usage error: Provide only one of --client-id, --object-id, --resource-id.')
322+
323+
identity_type = None
324+
identity_id = None
325+
if id_arg_count == 0:
326+
identity_type = MsiAccountTypes.system_assigned
327+
identity_id = None
328+
elif client_id:
329+
identity_type = MsiAccountTypes.user_assigned_client_id
330+
identity_id = client_id
331+
elif object_id:
332+
identity_type = MsiAccountTypes.user_assigned_object_id
333+
identity_id = object_id
334+
elif resource_id:
335+
identity_type = MsiAccountTypes.user_assigned_resource_id
336+
identity_id = resource_id
337+
338+
cred = MsiAccountTypes.msal_credential_factory(identity_type, identity_id)
320339
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
321340
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
322341
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
@@ -405,10 +424,10 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
405424

406425
elif managed_identity_type:
407426
# managed identity
408-
if _on_azure_arc():
409-
from .auth.msal_credentials import ManagedIdentityCredential
427+
if _use_msal_managed_identity(self.cli_ctx):
410428
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
411-
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
429+
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
430+
sdk_cred = CredentialAdaptor(cred)
412431
else:
413432
# The resource is merely used by msrestazure to get the first access token.
414433
# It is not actually used in an API invocation.
@@ -466,9 +485,9 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
466485
# managed identity
467486
if tenant:
468487
raise CLIError("Tenant shouldn't be specified for managed identity account")
469-
if _on_azure_arc():
470-
from .auth.msal_credentials import ManagedIdentityCredential
471-
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
488+
if _use_msal_managed_identity(self.cli_ctx):
489+
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
490+
sdk_cred = CredentialAdaptor(cred)
472491
else:
473492
from .auth.util import scopes_to_resource
474493
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
@@ -816,6 +835,18 @@ def msi_auth_factory(cli_account_name, identity, resource):
816835
return MSIAuthenticationWrapper(resource=resource, msi_res_id=identity)
817836
raise ValueError("unrecognized msi account name '{}'".format(cli_account_name))
818837

838+
@staticmethod
839+
def msal_credential_factory(id_type, id_value):
840+
from azure.cli.core.auth.msal_credentials import ManagedIdentityCredential
841+
if id_type == MsiAccountTypes.system_assigned:
842+
return ManagedIdentityCredential()
843+
if id_type == MsiAccountTypes.user_assigned_client_id:
844+
return ManagedIdentityCredential(client_id=id_value)
845+
if id_type == MsiAccountTypes.user_assigned_object_id:
846+
return ManagedIdentityCredential(object_id=id_value)
847+
if id_type == MsiAccountTypes.user_assigned_resource_id:
848+
return ManagedIdentityCredential(resource_id=id_value)
849+
raise ValueError("Unrecognized managed identity ID type '{}'".format(id_type))
819850

820851
class SubscriptionFinder:
821852
# An ARM client. It finds subscriptions for a user or service principal. It shouldn't do any
@@ -982,7 +1013,9 @@ def _create_identity_instance(cli_ctx, authority, tenant_id=None, client_id=None
9821013
instance_discovery=instance_discovery)
9831014

9841015

985-
def _on_azure_arc():
1016+
def _use_msal_managed_identity(cli_ctx):
9861017
# This indicates an Azure Arc-enabled server
9871018
from msal.managed_identity import get_managed_identity_source, AZURE_ARC
988-
return get_managed_identity_source() == AZURE_ARC
1019+
# PREVIEW: Use core.use_msal_managed_identity=true to enable managed identity authentication with MSAL
1020+
use_msal_managed_identity = cli_ctx.config.getboolean('core', 'use_msal_managed_identity', fallback=False)
1021+
return use_msal_managed_identity or get_managed_identity_source() == AZURE_ARC

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from knack.log import get_logger
1111
from knack.util import CLIError
1212
from msal import (PublicClientApplication, ConfidentialClientApplication,
13-
ManagedIdentityClient, SystemAssignedManagedIdentity)
13+
ManagedIdentityClient, SystemAssignedManagedIdentity, UserAssignedManagedIdentity)
1414

1515
from .constants import AZURE_CLI_CLIENT_ID
1616
from .util import check_result
@@ -131,9 +131,14 @@ class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
131131
Currently, only Azure Arc's system-assigned managed identity is supported.
132132
"""
133133

134-
def __init__(self):
134+
def __init__(self, client_id=None, resource_id=None, object_id=None):
135135
import requests
136-
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())
136+
if client_id or resource_id or object_id:
137+
managed_identity = UserAssignedManagedIdentity(
138+
client_id=client_id, resource_id=resource_id, object_id=object_id)
139+
else:
140+
managed_identity = SystemAssignedManagedIdentity()
141+
self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session())
137142

138143
def acquire_token(self, scopes, **kwargs):
139144
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)

0 commit comments

Comments
 (0)