Skip to content

Commit a2a6ff8

Browse files
committed
msal-credential-data
1 parent 4f49747 commit a2a6ff8

File tree

2 files changed

+146
-9
lines changed

2 files changed

+146
-9
lines changed

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,23 @@ def __init__(self, client_id, username, **kwargs):
4343

4444
self._account = accounts[0]
4545

46-
def acquire_token(self, scopes, claims_challenge=None, **kwargs):
46+
def acquire_token(self, scopes, claims_challenge=None, data=None, **kwargs):
4747
# scopes must be a list.
4848
# For acquiring SSH certificate, scopes is ['https://pas.windows.net/CheckMyAccess/Linux/.default']
49+
# data is only used for acquiring VM SSH certificate. DO NOT use it for other purposes.
4950
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
50-
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, kwargs=%r",
51-
scopes, claims_challenge, kwargs)
51+
logger.debug("UserCredential.acquire_token: scopes=%r, claims_challenge=%r, data=%r, kwargs=%r",
52+
scopes, claims_challenge, data, kwargs)
5253

5354
if claims_challenge:
5455
logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s',
5556
self._msal_app.authority.tenant, claims_challenge)
57+
58+
# Only pass data to MSAL if it is set. Passing data=None will cause failure in MSAL:
59+
# AttributeError: 'NoneType' object has no attribute 'get'
60+
if data is not None:
61+
kwargs['data'] = data
62+
5663
result = self._msal_app.acquire_token_silent_with_error(
5764
scopes, self._account, claims_challenge=claims_challenge, **kwargs)
5865

@@ -105,8 +112,13 @@ def __init__(self, client_id, client_credential, **kwargs):
105112
"""
106113
self._msal_app = ConfidentialClientApplication(client_id, client_credential=client_credential, **kwargs)
107114

108-
def acquire_token(self, scopes, **kwargs):
109-
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
115+
def acquire_token(self, scopes, data=None, **kwargs):
116+
logger.debug("ServicePrincipalCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
117+
scopes, data, kwargs)
118+
119+
if data is not None:
120+
kwargs['data'] = data
121+
110122
result = self._msal_app.acquire_token_for_client(scopes, **kwargs)
111123
check_result(result)
112124
return result
@@ -126,8 +138,13 @@ def __init__(self):
126138
# token_cache=...
127139
)
128140

129-
def acquire_token(self, scopes, **kwargs):
130-
logger.debug("CloudShellCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
141+
def acquire_token(self, scopes, data=None, **kwargs):
142+
logger.debug("CloudShellCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
143+
scopes, data, kwargs)
144+
145+
if data is not None:
146+
kwargs['data'] = data
147+
131148
result = self._msal_app.acquire_token_interactive(scopes, prompt="none", **kwargs)
132149
check_result(result, scopes=scopes)
133150
return result
@@ -147,8 +164,13 @@ def __init__(self, client_id=None, resource_id=None, object_id=None):
147164
managed_identity = SystemAssignedManagedIdentity()
148165
self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session())
149166

150-
def acquire_token(self, scopes, **kwargs):
151-
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
167+
def acquire_token(self, scopes, data=None, **kwargs):
168+
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, data=%r, kwargs=%r",
169+
scopes, data, kwargs)
170+
171+
if data is not None:
172+
from azure.cli.core.azclierror import AuthenticationError
173+
raise AuthenticationError("VM SSH currently doesn't support managed identity.")
152174

153175
from .util import scopes_to_resource
154176
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
7+
import unittest
8+
from unittest import mock
9+
10+
from ..msal_credentials import UserCredential
11+
12+
MOCK_ACCOUNT = {
13+
'account_source': 'authorization_code',
14+
'authority_type': 'MSSTS',
15+
'environment': 'login.microsoftonline.com',
16+
# random GUID generated by uuid.uuid4()
17+
'home_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718.e4e8e73b-5f99-4bd5-bdac-60b916a7343b',
18+
'local_account_id': '9d486bfc-8d91-4a65-a23e-33e1f01a1718',
19+
'realm': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b',
20+
'username': '[email protected]'
21+
}
22+
23+
MOCK_SCOPES = ['https://management.core.windows.net//.default']
24+
25+
MOCK_ACCESS_TOKEN = "mock_access_token"
26+
MOCK_MSAL_TOKEN = {
27+
'access_token': MOCK_ACCESS_TOKEN,
28+
'token_type': 'Bearer',
29+
'expires_in': 1800,
30+
'token_source': 'cache'
31+
}
32+
33+
MOCK_CLAIMS = {"test_claims": "value2"}
34+
35+
MOCK_DATA = {
36+
'key_id': 'test',
37+
'req_cnf': 'test',
38+
'token_type': 'ssh-cert'
39+
}
40+
MOCK_CERTIFICATE= "mock_certificate"
41+
MOCK_MSAL_CERTIFICATE = {
42+
'access_token': MOCK_CERTIFICATE,
43+
'client_info': 'test',
44+
'expires_in': 3599,
45+
'ext_expires_in': 3599,
46+
'foci': '1',
47+
'id_token': 'test',
48+
'id_token_claims': {
49+
'preferred_username': '[email protected]',
50+
'tid': 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b'
51+
},
52+
'refresh_token': 'test',
53+
'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default',
54+
'token_source': 'identity_provider',
55+
'token_type': 'ssh-cert'
56+
}
57+
58+
59+
class AuthorityStub:
60+
def __init__(self):
61+
self.tenant = 'e4e8e73b-5f99-4bd5-bdac-60b916a7343b'
62+
63+
class PublicClientApplicationStub:
64+
65+
def __init__(self, client_id, **kwargs):
66+
self.client_id = client_id
67+
self.authority = AuthorityStub()
68+
self.kwargs = kwargs
69+
self.acquire_token_silent_with_error_scopes = None
70+
self.acquire_token_silent_with_error_claims_challenge = None
71+
self.acquire_token_silent_with_error_kwargs = None
72+
super().__init__()
73+
74+
def get_accounts(self, username):
75+
return [MOCK_ACCOUNT]
76+
77+
def acquire_token_silent_with_error(self, scopes, account, **kwargs):
78+
self.acquire_token_silent_with_error_scopes = scopes
79+
self.acquire_token_silent_with_error_claims_challenge = scopes
80+
self.acquire_token_silent_with_error_kwargs = kwargs
81+
if 'data' in kwargs:
82+
return MOCK_MSAL_CERTIFICATE
83+
return MOCK_MSAL_TOKEN
84+
85+
86+
class TestUserCredential(unittest.TestCase):
87+
88+
@mock.patch('azure.cli.core.auth.msal_credentials.PublicClientApplication')
89+
def test_get_token(self, public_client_application_mock):
90+
public_client_application_mock.side_effect = PublicClientApplicationStub
91+
92+
msal_credential = UserCredential('test_client_id', 'test_username')
93+
msal_app = msal_credential._msal_app
94+
assert msal_credential._account == MOCK_ACCOUNT
95+
96+
result = msal_credential.acquire_token(MOCK_SCOPES)
97+
assert result == MOCK_MSAL_TOKEN
98+
assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES
99+
# Make sure data is not passed to MSAL
100+
assert 'data' not in msal_app.acquire_token_silent_with_error_kwargs
101+
102+
result = msal_credential.acquire_token(MOCK_SCOPES, claims_challenge=MOCK_CLAIMS)
103+
assert result == MOCK_MSAL_TOKEN
104+
assert msal_app.acquire_token_silent_with_error_scopes == MOCK_SCOPES
105+
assert msal_app.acquire_token_silent_with_error_kwargs['claims_challenge'] == MOCK_CLAIMS
106+
107+
result = msal_credential.acquire_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'],
108+
data=MOCK_DATA)
109+
assert result == MOCK_MSAL_CERTIFICATE
110+
assert msal_app.acquire_token_silent_with_error_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default']
111+
assert msal_app.acquire_token_silent_with_error_kwargs['data'] == MOCK_DATA
112+
113+
114+
if __name__ == '__main__':
115+
unittest.main()

0 commit comments

Comments
 (0)