Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,23 @@ def __init__(self, client_id, username, **kwargs):

self._account = accounts[0]

def acquire_token(self, scopes, claims_challenge=None, **kwargs):
def acquire_token(self, scopes, claims_challenge=None, data=None, **kwargs):
# scopes must be a list.
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
# data is only used for acquiring VM SSH certificate. DO NOT use it for other purposes.
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, kwargs=%r",
scopes, claims_challenge, kwargs)
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, data=%r, kwargs=%r",
scopes, claims_challenge, data, kwargs)

if claims_challenge:
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
self._msal_app.authority.tenant, claims_challenge)

# Only pass data to MSAL if it is set. Passing data=None will cause failure in MSAL:
# AttributeError: 'NoneType' object has no attribute 'get'
if data is not None:
kwargs['data'] = data

result = self._msal_app.acquire_token_silent_with_error(
scopes, self._account, claims_challenge=claims_challenge, **kwargs)

Expand Down Expand Up @@ -105,8 +112,13 @@ def __init__(self, client_id, client_credential, **kwargs):
"""
self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs)

def acquire_token(self, scopes, **kwargs):
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
def acquire_token(self, scopes, data=None, **kwargs):
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
scopes, data, kwargs)

if data is not None:
kwargs['data'] = data

result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
check_result(result)
return result
Expand All @@ -126,8 +138,13 @@ def __init__(self):
# token_cache=...
)

def acquire_token(self, scopes, **kwargs):
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
def acquire_token(self, scopes, data=None, **kwargs):
logger.debug("CloudShellCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
scopes, data, kwargs)

if data is not None:
kwargs['data'] = data

result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
check_result(result, scopes=scopes)
return result
Expand All @@ -147,8 +164,13 @@ def __init__(self, client_id=None, resource_id=None, object_id=None):
managed_identity = SystemAssignedManagedIdentity()
self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session())

def acquire_token(self, scopes, **kwargs):
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
def acquire_token(self, scopes, data=None, **kwargs):
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
scopes, data, kwargs)

if data is not None:
from azure.cli.core.azclierror import AuthenticationError
raise AuthenticationError("VM SSH currently doesn't support managed identity.")

from .util import scopes_to_resource
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


import unittest
from unittest import mock

from ..msal_credentials import UserCredential

MOCK_ACCOUNT = {
'account_source': 'authorization_code',
'authority_type': 'MSSTS',
'environment': 'login.microsoftonline.com',
# random GUID generated by uuid.uuid4()
'home_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718.e4e8e73b-5f99-4bd5-bdac-60b916a7343b',
'local_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718',
'realm': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b',
'username': 'test@microsoft.com'
}

MOCK_SCOPES = ['https://management.core.windows.net//.default']

MOCK_ACCESS_TOKEN = "mock_access_token"
MOCK_MSAL_TOKEN = {
'access_token': MOCK_ACCESS_TOKEN,
'token_type': 'Bearer',
'expires_in': 1800,
'token_source': 'cache'
}

MOCK_CLAIMS = {"test_claims": "value2"}

MOCK_DATA = {
'key_id': 'test',
'req_cnf': 'test',
'token_type': 'ssh-cert'
}
MOCK_CERTIFICATE= "mock_certificate"
MOCK_MSAL_CERTIFICATE = {
'access_token': MOCK_CERTIFICATE,
'client_info': 'test',
'expires_in': 3599,
'ext_expires_in': 3599,
'foci': '1',
'id_token': 'test',
'id_token_claims': {
'preferred_username': 'test@microsoft.com',
'tid': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b'
},
'refresh_token': 'test',
'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default',
'token_source': 'identity_provider',
'token_type': 'ssh-cert'
}


class AuthorityStub:
def __init__(self):
self.tenant = 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b'

class PublicClientApplicationStub:

def __init__(self, client_id, **kwargs):
self.client_id = client_id
self.authority = AuthorityStub()
self.kwargs = kwargs
self.acquire_token_silent_with_error_scopes = None
self.acquire_token_silent_with_error_claims_challenge = None
self.acquire_token_silent_with_error_kwargs = None
super().__init__()

def get_accounts(self, username):
return [MOCK_ACCOUNT]

def acquire_token_silent_with_error(self, scopes, account, **kwargs):
self.acquire_token_silent_with_error_scopes = scopes
self.acquire_token_silent_with_error_claims_challenge = scopes
self.acquire_token_silent_with_error_kwargs = kwargs
if 'data' in kwargs:
return MOCK_MSAL_CERTIFICATE
return MOCK_MSAL_TOKEN


class TestUserCredential(unittest.TestCase):

@mock.patch('azure.cli.core.auth.msal_credentials.PublicClientApplication')
def test_get_token(self, public_client_application_mock):
public_client_application_mock.side_effect = PublicClientApplicationStub

msal_credential = UserCredential('test_client_id', 'test_username')
msal_app = msal_credential._msal_app
assert msal_credential._account == MOCK_ACCOUNT

result = msal_credential.acquire_token(MOCK_SCOPES)
assert result == MOCK_MSAL_TOKEN
assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES
# Make sure data is not passed to MSAL
assert 'data' not in msal_app.acquire_token_silent_with_error_kwargs

result = msal_credential.acquire_token(MOCK_SCOPES, claims_challenge=MOCK_CLAIMS)
assert result == MOCK_MSAL_TOKEN
assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES
assert msal_app.acquire_token_silent_with_error_kwargs['claims_challenge'] == MOCK_CLAIMS

result = msal_credential.acquire_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'],
data=MOCK_DATA)
assert result == MOCK_MSAL_CERTIFICATE
assert msal_app.acquire_token_silent_with_error_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default']
assert msal_app.acquire_token_silent_with_error_kwargs['data'] == MOCK_DATA


if __name__ == '__main__':
unittest.main()
Loading