5656_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
5757_COMMON_TENANT = 'common'
5858
59- _MSI_ACCOUNT_NAME = 'MSI@'
6059_TENANT_LEVEL_ACCOUNT_NAME = 'N/A(tenant level account)'
6160
6261
@@ -254,9 +253,10 @@ def _new_account():
254253 s .state = StateType .enabled
255254 return s
256255
257- def find_subscriptions_in_vm_with_msi (self , msi_port ):
256+ def find_subscriptions_in_vm_with_msi (self , msi_port , identity_id = None ):
258257 import jwt
259- _ , token , _ = Profile .get_msi_token (CLOUD .endpoints .active_directory_resource_id , msi_port )
258+ token , identity_id_type = Profile .get_msi_token (CLOUD .endpoints .active_directory_resource_id ,
259+ msi_port , identity_id , for_login = True )
260260 logger .info ('MSI: token was retrieved. Now trying to initialize local accounts...' )
261261 decode = jwt .decode (token , verify = False , algorithms = ['RS256' ])
262262 tenant = decode ['tid' ]
@@ -265,32 +265,35 @@ def find_subscriptions_in_vm_with_msi(self, msi_port):
265265 subscriptions = subscription_finder .find_from_raw_token (tenant , token )
266266 if not subscriptions :
267267 raise CLIError ('No access was configured for the VM, hence no subscriptions were found' )
268- consolidated = Profile ._normalize_properties ('VM' , subscriptions , is_service_principal = True )
268+ base_name = '{}-{}' .format (identity_id_type , identity_id ) if identity_id else identity_id_type
269+ user = 'userAssignedIdentity' if identity_id else 'systemAssignedIdentity'
270+ consolidated = Profile ._normalize_properties (user , subscriptions , is_service_principal = True )
269271 for s in consolidated :
270272 # use a special name to trigger a special token acquisition
271- s [_SUBSCRIPTION_NAME ] = "{}{}" .format (_MSI_ACCOUNT_NAME , msi_port )
272- self ._set_subscriptions (consolidated )
273+ s [_SUBSCRIPTION_NAME ] = "{}@{}" .format (base_name , msi_port )
274+ # key-off subscription name to allow accounts with same id(but under different identities)
275+ self ._set_subscriptions (consolidated , key_name = _SUBSCRIPTION_NAME )
273276 return deepcopy (consolidated )
274277
275- def _set_subscriptions (self , new_subscriptions , merge = True ):
278+ def _set_subscriptions (self , new_subscriptions , merge = True , key_name = _SUBSCRIPTION_ID ):
276279 existing_ones = self .load_cached_subscriptions (all_clouds = True )
277280 active_one = next ((x for x in existing_ones if x .get (_IS_DEFAULT_SUBSCRIPTION )), None )
278- active_subscription_id = active_one [_SUBSCRIPTION_ID ] if active_one else None
281+ active_subscription_id = active_one [key_name ] if active_one else None
279282 active_cloud = get_active_cloud ()
280283 default_sub_id = None
281284
282285 # merge with existing ones
283286 if merge :
284- dic = collections .OrderedDict ((x [_SUBSCRIPTION_ID ], x ) for x in existing_ones )
287+ dic = collections .OrderedDict ((x [key_name ], x ) for x in existing_ones )
285288 else :
286289 dic = collections .OrderedDict ()
287290
288- dic .update ((x [_SUBSCRIPTION_ID ], x ) for x in new_subscriptions )
291+ dic .update ((x [key_name ], x ) for x in new_subscriptions )
289292 subscriptions = list (dic .values ())
290293 if subscriptions :
291294 if active_one :
292295 new_active_one = next (
293- (x for x in new_subscriptions if x [_SUBSCRIPTION_ID ] == active_subscription_id ),
296+ (x for x in new_subscriptions if x [key_name ] == active_subscription_id ),
294297 None )
295298
296299 for s in subscriptions :
@@ -384,15 +387,26 @@ def get_access_token_for_resource(self, username, tenant, resource):
384387 username , tenant , resource )
385388 return access_token
386389
390+ @staticmethod
391+ def _try_parse_for_msi_port (subscription_name ):
392+ if '@' in subscription_name :
393+ try :
394+ parts = subscription_name .split ('@' , 1 )
395+ return parts [0 ], int (parts [1 ])
396+ except ValueError :
397+ pass
398+ return None , None
399+
387400 def get_login_credentials (self , resource = CLOUD .endpoints .active_directory_resource_id ,
388401 subscription_id = None ):
389402 account = self .get_subscription (subscription_id )
390403 user_type = account [_USER_ENTITY ][_USER_TYPE ]
391404 username_or_sp_id = account [_USER_ENTITY ][_USER_NAME ]
392405
393406 def _retrieve_token ():
394- if account [_SUBSCRIPTION_NAME ].startswith (_MSI_ACCOUNT_NAME ):
395- return Profile .get_msi_token (resource , account [_SUBSCRIPTION_NAME ][len (_MSI_ACCOUNT_NAME ):])
407+ identity_id , msi_port = Profile ._try_parse_for_msi_port (account [_SUBSCRIPTION_NAME ])
408+ if msi_port is not None :
409+ return Profile .get_msi_token (resource , msi_port , identity_id )
396410 elif user_type == _USER :
397411 return self ._creds_cache .retrieve_token_for_user (username_or_sp_id ,
398412 account [_TENANT_ID ], resource )
@@ -435,8 +449,10 @@ def get_raw_token(self, resource, subscription=None):
435449 str (account [_TENANT_ID ]))
436450
437451 def refresh_accounts (self , subscription_finder = None ):
452+ import re
438453 subscriptions = self .load_cached_subscriptions ()
439- to_refresh = [s for s in subscriptions if not s [_SUBSCRIPTION_NAME ].startswith (_MSI_ACCOUNT_NAME )]
454+ # filter away MSI related ones whose name always end with '@<port-number>'
455+ to_refresh = [s for s in subscriptions if not re .match ('@[0-9]+$' , s [_SUBSCRIPTION_NAME ])]
440456 not_to_refresh = [s for s in subscriptions if s not in to_refresh ]
441457
442458 from azure .cli .core ._debug import allow_debug_adal_connection
@@ -536,13 +552,43 @@ def get_installation_id(self):
536552 return installation_id
537553
538554 @staticmethod
539- def get_msi_token (resource , port ):
555+ def get_msi_token (resource , port , identity_id = None , for_login = False ):
540556 import requests
541557 import time
558+ from msrestazure .tools import is_valid_resource_id
559+ _System_Assigned_Id_Type = 'MSI'
560+ _User_Assigned_Client_Id_type = 'MSIClient'
561+ _User_assigned_Object_Id_Type = 'MSIObject'
562+ _User_assigned_Resource_Id_Type = 'MSIResource'
563+
542564 request_uri = 'http://localhost:{}/oauth2/token' .format (port )
543565 payload = {
544566 'resource' : resource
545567 }
568+ identity_id_type = None
569+ if for_login : # we will figure out the right type of id here
570+ if not identity_id :
571+ identity_id_type = _System_Assigned_Id_Type
572+ elif is_valid_resource_id (identity_id ):
573+ payload ['msi_res_id' ] = identity_id
574+ identity_id_type = _User_assigned_Resource_Id_Type
575+ else : # try to sniff it
576+ payload ['client_id' ] = identity_id
577+ identity_id_type = _User_Assigned_Client_Id_type
578+ result = requests .post (request_uri , data = payload , headers = {'Metadata' : 'true' })
579+ if result .status_code != 200 :
580+ payload .pop ('client_id' )
581+ payload ['object_id' ] = identity_id
582+ identity_id_type = _User_assigned_Object_Id_Type
583+ else :
584+ parts = identity_id .split ('-' , 1 )
585+ identity_id_type = parts [0 ]
586+ if parts [0 ] == _User_assigned_Resource_Id_Type :
587+ payload ['msi_res_id' ] = parts [1 ]
588+ elif parts [0 ] == _User_Assigned_Client_Id_type :
589+ payload ['client_id' ] = parts [1 ]
590+ elif parts [0 ] == _User_assigned_Object_Id_Type :
591+ payload ['object_id' ] = parts [1 ]
546592
547593 # retry as the token endpoint might not be available yet, one example is you use CLI in a
548594 # custom script extension of VMSS, which might get provisioned before the MSI extensioon
@@ -567,6 +613,8 @@ def get_msi_token(resource, port):
567613 logger .debug ('MSI: token retrieved' )
568614 break
569615 token_entry = json .loads (result .content .decode ())
616+ if for_login :
617+ return token_entry ['access_token' ], identity_id_type
570618 return token_entry ['token_type' ], token_entry ['access_token' ], token_entry
571619
572620
0 commit comments