@@ -138,6 +138,10 @@ def create_instance(
138138 location = location ,
139139 )
140140
141+ managed_identity_resource_group , managed_identity_name = parse_vm_managed_identity (
142+ self .config .vm_managed_identity
143+ )
144+
141145 base_tags = {
142146 "owner" : "dstack" ,
143147 "dstack_project" : instance_config .project_name ,
@@ -161,7 +165,8 @@ def create_instance(
161165 network_security_group = network_security_group ,
162166 network = network ,
163167 subnet = subnet ,
164- managed_identity = None ,
168+ managed_identity_name = managed_identity_name ,
169+ managed_identity_resource_group = managed_identity_resource_group ,
165170 image_reference = _get_image_ref (
166171 compute_client = self ._compute_client ,
167172 location = location ,
@@ -257,7 +262,8 @@ def create_gateway(
257262 network_security_group = network_security_group ,
258263 network = network ,
259264 subnet = subnet ,
260- managed_identity = None ,
265+ managed_identity_name = None ,
266+ managed_identity_resource_group = None ,
261267 image_reference = _get_gateway_image_ref (),
262268 vm_size = "Standard_B1ms" ,
263269 instance_name = instance_name ,
@@ -340,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
340346 return resource_group , network_name , subnet_name
341347
342348
349+ def parse_vm_managed_identity (
350+ vm_managed_identity : Optional [str ],
351+ ) -> Tuple [Optional [str ], Optional [str ]]:
352+ if vm_managed_identity is None :
353+ return None , None
354+ try :
355+ resource_group , managed_identity = vm_managed_identity .split ("/" )
356+ return resource_group , managed_identity
357+ except Exception :
358+ raise ComputeError (
359+ "`vm_managed_identity` specified in incorrect format."
360+ " Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
361+ )
362+
363+
343364def _parse_config_vpc_id (vpc_id : str ) -> Tuple [str , str ]:
344365 resource_group , network_name = vpc_id .split ("/" )
345366 return resource_group , network_name
@@ -468,7 +489,8 @@ def _launch_instance(
468489 network_security_group : str ,
469490 network : str ,
470491 subnet : str ,
471- managed_identity : Optional [str ],
492+ managed_identity_name : Optional [str ],
493+ managed_identity_resource_group : Optional [str ],
472494 image_reference : ImageReference ,
473495 vm_size : str ,
474496 instance_name : str ,
@@ -490,6 +512,20 @@ def _launch_instance(
490512 public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration (
491513 name = "public_ip_config" ,
492514 )
515+ managed_identity = None
516+ if managed_identity_name is not None :
517+ if managed_identity_resource_group is None :
518+ managed_identity_resource_group = resource_group
519+ managed_identity = VirtualMachineIdentity (
520+ type = ResourceIdentityType .USER_ASSIGNED ,
521+ user_assigned_identities = {
522+ azure_utils .get_managed_identity_id (
523+ subscription_id ,
524+ managed_identity_resource_group ,
525+ managed_identity_name ,
526+ ): UserAssignedIdentitiesValue (),
527+ },
528+ )
493529 try :
494530 poller = compute_client .virtual_machines .begin_create_or_update (
495531 resource_group ,
@@ -554,16 +590,7 @@ def _launch_instance(
554590 ),
555591 priority = "Spot" if spot else "Regular" ,
556592 eviction_policy = "Delete" if spot else None ,
557- identity = None
558- if managed_identity is None
559- else VirtualMachineIdentity (
560- type = ResourceIdentityType .USER_ASSIGNED ,
561- user_assigned_identities = {
562- azure_utils .get_managed_identity_id (
563- subscription_id , resource_group , managed_identity
564- ): UserAssignedIdentitiesValue ()
565- },
566- ),
593+ identity = managed_identity ,
567594 user_data = base64 .b64encode (user_data .encode ()).decode (),
568595 tags = tags ,
569596 ),
0 commit comments