Skip to content

Commit f00aeae

Browse files
authored
Support vm_managed_identity for Azure (#2608)
1 parent c318d83 commit f00aeae

File tree

5 files changed

+73
-13
lines changed

5 files changed

+73
-13
lines changed

docs/docs/concepts/backends.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ There are two ways to configure Azure: using a client secret or using the defaul
338338
"Microsoft.Compute/disks/write",
339339
"Microsoft.Compute/disks/read",
340340
"Microsoft.Compute/disks/delete",
341+
"Microsoft.ManagedIdentity/userAssignedIdentities/assign/action",
342+
"Microsoft.ManagedIdentity/userAssignedIdentities/read",
341343
"Microsoft.Network/networkSecurityGroups/*",
342344
"Microsoft.Network/locations/*",
343345
"Microsoft.Network/virtualNetworks/*",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ azure = [
139139
"azure-mgmt-network>=23.0.0,<28.0.0",
140140
"azure-mgmt-resource>=22.0.0",
141141
"azure-mgmt-authorization>=3.0.0",
142+
"azure-mgmt-msi>=7.0.0",
142143
"dstack[server]",
143144
]
144145
gcp = [

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ def create_instance(
138138
location=location,
139139
)
140140

141+
managed_identity_resource_group, managed_identity_name = parse_vm_managed_identity(
142+
self.config.vm_managed_identity
143+
)
144+
141145
base_tags = {
142146
"owner": "dstack",
143147
"dstack_project": instance_config.project_name,
@@ -161,7 +165,8 @@ def create_instance(
161165
network_security_group=network_security_group,
162166
network=network,
163167
subnet=subnet,
164-
managed_identity=None,
168+
managed_identity_name=managed_identity_name,
169+
managed_identity_resource_group=managed_identity_resource_group,
165170
image_reference=_get_image_ref(
166171
compute_client=self._compute_client,
167172
location=location,
@@ -257,7 +262,8 @@ def create_gateway(
257262
network_security_group=network_security_group,
258263
network=network,
259264
subnet=subnet,
260-
managed_identity=None,
265+
managed_identity_name=None,
266+
managed_identity_resource_group=None,
261267
image_reference=_get_gateway_image_ref(),
262268
vm_size="Standard_B1ms",
263269
instance_name=instance_name,
@@ -340,6 +346,21 @@ def get_resource_group_network_subnet_or_error(
340346
return resource_group, network_name, subnet_name
341347

342348

349+
def parse_vm_managed_identity(
350+
vm_managed_identity: Optional[str],
351+
) -> Tuple[Optional[str], Optional[str]]:
352+
if vm_managed_identity is None:
353+
return None, None
354+
try:
355+
resource_group, managed_identity = vm_managed_identity.split("/")
356+
return resource_group, managed_identity
357+
except Exception:
358+
raise ComputeError(
359+
"`vm_managed_identity` specified in incorrect format."
360+
" Supported format: 'managedIdentityResourceGroup/managedIdentityName'"
361+
)
362+
363+
343364
def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
344365
resource_group, network_name = vpc_id.split("/")
345366
return resource_group, network_name
@@ -468,7 +489,8 @@ def _launch_instance(
468489
network_security_group: str,
469490
network: str,
470491
subnet: str,
471-
managed_identity: Optional[str],
492+
managed_identity_name: Optional[str],
493+
managed_identity_resource_group: Optional[str],
472494
image_reference: ImageReference,
473495
vm_size: str,
474496
instance_name: str,
@@ -490,6 +512,20 @@ def _launch_instance(
490512
public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration(
491513
name="public_ip_config",
492514
)
515+
managed_identity = None
516+
if managed_identity_name is not None:
517+
if managed_identity_resource_group is None:
518+
managed_identity_resource_group = resource_group
519+
managed_identity = VirtualMachineIdentity(
520+
type=ResourceIdentityType.USER_ASSIGNED,
521+
user_assigned_identities={
522+
azure_utils.get_managed_identity_id(
523+
subscription_id,
524+
managed_identity_resource_group,
525+
managed_identity_name,
526+
): UserAssignedIdentitiesValue(),
527+
},
528+
)
493529
try:
494530
poller = compute_client.virtual_machines.begin_create_or_update(
495531
resource_group,
@@ -554,16 +590,7 @@ def _launch_instance(
554590
),
555591
priority="Spot" if spot else "Regular",
556592
eviction_policy="Delete" if spot else None,
557-
identity=None
558-
if managed_identity is None
559-
else VirtualMachineIdentity(
560-
type=ResourceIdentityType.USER_ASSIGNED,
561-
user_assigned_identities={
562-
azure_utils.get_managed_identity_id(
563-
subscription_id, resource_group, managed_identity
564-
): UserAssignedIdentitiesValue()
565-
},
566-
),
593+
identity=managed_identity,
567594
user_data=base64.b64encode(user_data.encode()).decode(),
568595
tags=tags,
569596
),

src/dstack/_internal/core/backends/azure/configurator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import azure.core.exceptions
66
from azure.core.credentials import TokenCredential
7+
from azure.mgmt import msi as msi_mgmt
78
from azure.mgmt import network as network_mgmt
89
from azure.mgmt import resource as resource_mgmt
910
from azure.mgmt import subscription as subscription_mgmt
@@ -97,6 +98,7 @@ def validate_config(self, config: AzureBackendConfigWithCreds, default_creds_ena
9798
self._check_config_locations(config)
9899
self._check_config_tags(config)
99100
self._check_config_resource_group(config=config, credential=credential)
101+
self._check_config_vm_managed_identity(config=config, credential=credential)
100102
self._check_config_vpc(config=config, credential=credential)
101103

102104
def create_backend(
@@ -260,6 +262,25 @@ def _check_config_vpc(
260262
except BackendError as e:
261263
raise ServerClientError(e.args[0])
262264

265+
def _check_config_vm_managed_identity(
266+
self, config: AzureBackendConfigWithCreds, credential: auth.AzureCredential
267+
):
268+
try:
269+
resource_group, identity_name = compute.parse_vm_managed_identity(
270+
config.vm_managed_identity
271+
)
272+
except BackendError as e:
273+
raise ServerClientError(e.args[0])
274+
if resource_group is None or identity_name is None:
275+
return
276+
msi_client = msi_mgmt.ManagedServiceIdentityClient(credential, config.subscription_id)
277+
try:
278+
msi_client.user_assigned_identities.get(resource_group, identity_name)
279+
except azure.core.exceptions.ResourceNotFoundError:
280+
raise ServerClientError(
281+
f"Managed identity {identity_name} not found in resource group {resource_group}"
282+
)
283+
263284
def _set_client_creds_tenant_id(
264285
self,
265286
creds: AzureClientCreds,

src/dstack/_internal/core/backends/azure/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ class AzureBackendConfig(CoreModel):
6262
)
6363
),
6464
] = None
65+
vm_managed_identity: Annotated[
66+
Optional[str],
67+
Field(
68+
description=(
69+
"The managed identity to associate with provisioned VMs."
70+
" Must have a format `managedIdentityResourceGroup/managedIdentityName`"
71+
)
72+
),
73+
] = None
6574
tags: Annotated[
6675
Optional[Dict[str, str]],
6776
Field(description="The tags that will be assigned to resources created by `dstack`"),

0 commit comments

Comments
 (0)