Skip to content

Commit 01190a0

Browse files
committed
WIP: support GCP placement policies
1 parent d6e5569 commit 01190a0

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ComputeWithCreateInstanceSupport,
1616
ComputeWithGatewaySupport,
1717
ComputeWithMultinodeSupport,
18+
ComputeWithPlacementGroupSupport,
1819
ComputeWithVolumeSupport,
1920
generate_unique_gateway_instance_name,
2021
generate_unique_instance_name,
@@ -46,6 +47,7 @@
4647
InstanceType,
4748
Resources,
4849
)
50+
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
4951
from dstack._internal.core.models.resources import Memory, Range
5052
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
5153
from dstack._internal.core.models.volumes import (
@@ -74,6 +76,7 @@ class GCPVolumeDiskBackendData(CoreModel):
7476
class GCPCompute(
7577
ComputeWithCreateInstanceSupport,
7678
ComputeWithMultinodeSupport,
79+
ComputeWithPlacementGroupSupport,
7780
ComputeWithGatewaySupport,
7881
ComputeWithVolumeSupport,
7982
Compute,
@@ -89,6 +92,9 @@ def __init__(self, config: GCPConfig):
8992
self.routers_client = compute_v1.RoutersClient(credentials=self.credentials)
9093
self.tpu_client = tpu_v2.TpuClient(credentials=self.credentials)
9194
self.disk_client = compute_v1.DisksClient(credentials=self.credentials)
95+
self.resource_policies_client = compute_v1.ResourcePoliciesClient(
96+
credentials=self.credentials
97+
)
9298

9399
def get_offers(
94100
self, requirements: Optional[Requirements] = None
@@ -282,6 +288,7 @@ def create_instance(
282288
network=self.config.vpc_resource_name,
283289
subnetwork=subnetwork,
284290
allocate_public_ip=allocate_public_ip,
291+
placement_policy=f"projects/{self.config.project_id}/regions/{instance_offer.region}/resourcePolicies/{instance_config.placement_group_name}",
285292
)
286293
try:
287294
# GCP needs some time to return an error in case of no capacity (< 30s).
@@ -374,6 +381,39 @@ def update_provisioning_data(
374381
f"Failed to get instance IP address. Instance status: {instance.status}"
375382
)
376383

384+
def create_placement_group(
385+
self,
386+
placement_group: PlacementGroup,
387+
) -> PlacementGroupProvisioningData:
388+
policy = compute_v1.ResourcePolicy(
389+
name=placement_group.name,
390+
region=placement_group.configuration.region,
391+
group_placement_policy=compute_v1.ResourcePolicyGroupPlacementPolicy(
392+
availability_domain_count=1,
393+
collocation="COLLOCATED",
394+
),
395+
)
396+
self.resource_policies_client.insert(
397+
project=self.config.project_id,
398+
region=placement_group.configuration.region,
399+
resource_policy_resource=policy,
400+
)
401+
return PlacementGroupProvisioningData(backend=BackendType.GCP)
402+
403+
def delete_placement_group(
404+
self,
405+
placement_group: PlacementGroup,
406+
):
407+
try:
408+
operation = self.resource_policies_client.delete(
409+
project=self.config.project_id,
410+
region=placement_group.configuration.region,
411+
resource_policy=placement_group.name,
412+
)
413+
operation.result() # Wait for operation to complete
414+
except google.api_core.exceptions.NotFound:
415+
logger.debug("Placement group %s not found", placement_group.name)
416+
377417
def create_gateway(
378418
self,
379419
configuration: GatewayComputeConfiguration,

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def create_instance_struct(
117117
network: str = "global/networks/default",
118118
subnetwork: Optional[str] = None,
119119
allocate_public_ip: bool = True,
120+
placement_policy: Optional[str] = None,
120121
) -> compute_v1.Instance:
121122
instance = compute_v1.Instance()
122123
instance.name = instance_name
@@ -149,6 +150,9 @@ def create_instance_struct(
149150
# Attachable GPUs, H100, A100, and L4
150151
instance.scheduling.on_host_maintenance = "TERMINATE"
151152

153+
if placement_policy is not None:
154+
instance.resource_policies = [placement_policy]
155+
152156
if spot:
153157
instance.scheduling = compute_v1.Scheduling()
154158
instance.scheduling.provisioning_model = compute_v1.Scheduling.ProvisioningModel.SPOT.name

0 commit comments

Comments
 (0)