Skip to content

Commit 8c65f4a

Browse files
authored
{Auth} Bring back get_msal_token for acquiring VM SSH certificate (#31082)
1 parent 47d26b5 commit 8c65f4a

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

src/azure-cli-core/azure/cli/core/_profile.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -299,26 +299,24 @@ def logout_all(self):
299299
identity.logout_all_users()
300300
identity.logout_all_service_principal()
301301

302-
def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None):
302+
def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, aux_tenants=None,
303+
sdk_credential=True):
303304
"""Get a credential compatible with Track 2 SDK."""
304305
if aux_tenants and aux_subscriptions:
305306
raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both")
306307

307308
account = self.get_subscription(subscription_id)
308309

309310
managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account)
310-
311+
external_credentials = None
311312
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
312313
# Cloud Shell
313314
from .auth.msal_credentials import CloudShellCredential
314-
# The credential must be wrapped by CredentialAdaptor so that it can work with SDK.
315-
sdk_cred = CredentialAdaptor(CloudShellCredential())
315+
cred = CloudShellCredential()
316316

317317
elif managed_identity_type:
318318
# managed identity
319-
# The credential must be wrapped by CredentialAdaptor so that it can work with SDK.
320319
cred = ManagedIdentityAuth.credential_factory(managed_identity_type, managed_identity_id)
321-
sdk_cred = CredentialAdaptor(cred)
322320

323321
else:
324322
# user and service principal
@@ -332,13 +330,15 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
332330
if sub[_TENANT_ID] != account[_TENANT_ID]:
333331
external_tenants.append(sub[_TENANT_ID])
334332

335-
credential = self._create_credential(account)
333+
cred = self._create_credential(account)
336334
external_credentials = []
337335
for external_tenant in external_tenants:
338336
external_credentials.append(self._create_credential(account, tenant_id=external_tenant))
339-
sdk_cred = CredentialAdaptor(credential, auxiliary_credentials=external_credentials)
340337

341-
return (sdk_cred,
338+
# Wrapping the credential with CredentialAdaptor makes it compatible with SDK.
339+
cred_result = CredentialAdaptor(cred, auxiliary_credentials=external_credentials) if sdk_credential else cred
340+
341+
return (cred_result,
342342
str(account[_SUBSCRIPTION_ID]),
343343
str(account[_TENANT_ID]))
344344

@@ -401,6 +401,15 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
401401
None if tenant else str(account[_SUBSCRIPTION_ID]),
402402
str(tenant if tenant else account[_TENANT_ID]))
403403

404+
def get_msal_token(self, scopes, data):
405+
"""Get VM SSH certificate. DO NOT use it for other purposes. To get an access token, use get_raw_token instead.
406+
"""
407+
credential, _, _ = self.get_login_credentials(sdk_credential=False)
408+
from .auth.constants import ACCESS_TOKEN
409+
certificate_string = credential.acquire_token(scopes, data=data)[ACCESS_TOKEN]
410+
# The first value used to be username, but it is no longer used.
411+
return None, certificate_string
412+
404413
def _normalize_properties(self, user, subscriptions, is_service_principal, cert_sn_issuer_auth=None,
405414
assigned_identity_info=None):
406415
consolidated = []

src/azure-cli-core/azure/cli/core/tests/test_profile.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ def __init__(self, *args, **kwargs):
5050
# If acquire_token_scopes is checked, make sure to create a new instance of MsalCredentialStub
5151
# to avoid interference from other tests.
5252
self.acquire_token_scopes = None
53+
self.acquire_token_data=None
5354
super().__init__()
5455

5556
def acquire_token(self, scopes, **kwargs):
5657
self.acquire_token_scopes = scopes
58+
self.acquire_token_data = kwargs.get('data')
5759
return {
5860
'access_token': MOCK_ACCESS_TOKEN,
5961
'token_type': 'Bearer',
@@ -1287,6 +1289,31 @@ def cloud_shell_credential_factory():
12871289
with self.assertRaisesRegex(CLIError, 'Cloud Shell'):
12881290
profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id)
12891291

1292+
@mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential')
1293+
def test_get_msal_token(self, get_user_credential_mock):
1294+
credential_mock_temp = MsalCredentialStub()
1295+
get_user_credential_mock.return_value = credential_mock_temp
1296+
cli = DummyCli()
1297+
1298+
storage_mock = {'subscriptions': None}
1299+
profile = Profile(cli_ctx=cli, storage=storage_mock)
1300+
consolidated = profile._normalize_properties(self.user1,
1301+
[self.subscription1],
1302+
False, None, None)
1303+
profile._set_subscriptions(consolidated)
1304+
1305+
MOCK_DATA = {
1306+
'key_id': 'test',
1307+
'req_cnf': 'test',
1308+
'token_type': 'ssh-cert'
1309+
}
1310+
result = profile.get_msal_token(['https://pas.windows.net/CheckMyAccess/Linux/.default'],
1311+
MOCK_DATA)
1312+
1313+
assert result == (None, MOCK_ACCESS_TOKEN)
1314+
assert credential_mock_temp.acquire_token_scopes == ['https://pas.windows.net/CheckMyAccess/Linux/.default']
1315+
assert credential_mock_temp.acquire_token_data == MOCK_DATA
1316+
12901317
@mock.patch('azure.cli.core.auth.identity.Identity.logout_service_principal')
12911318
@mock.patch('azure.cli.core.auth.identity.Identity.logout_user')
12921319
def test_logout(self, logout_user_mock, logout_service_principal_mock):

0 commit comments

Comments
 (0)