Skip to content

Commit bc86d64

Browse files
committed
Handle placement group in use
1 parent 01190a0 commit bc86d64

File tree

2 files changed

+72
-52
lines changed

2 files changed

+72
-52
lines changed

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
ComputeError,
3232
ComputeResourceNotFoundError,
3333
NoCapacityError,
34+
PlacementGroupInUseError,
3435
ProvisioningError,
3536
)
3637
from dstack._internal.core.models.backends.base import BackendType
@@ -189,6 +190,13 @@ def create_instance(
189190
config=self.config,
190191
region=instance_offer.region,
191192
)
193+
placement_policy = None
194+
if instance_config.placement_group_name is not None:
195+
placement_policy = gcp_resources.get_placement_policy_resource_name(
196+
project_id=self.config.project_id,
197+
region=instance_offer.region,
198+
placement_policy=instance_config.placement_group_name,
199+
)
192200
labels = {
193201
"owner": "dstack",
194202
"dstack_project": instance_config.project_name.lower(),
@@ -288,7 +296,7 @@ def create_instance(
288296
network=self.config.vpc_resource_name,
289297
subnetwork=subnetwork,
290298
allocate_public_ip=allocate_public_ip,
291-
placement_policy=f"projects/{self.config.project_id}/regions/{instance_offer.region}/resourcePolicies/{instance_config.placement_group_name}",
299+
placement_policy=placement_policy,
292300
)
293301
try:
294302
# GCP needs some time to return an error in case of no capacity (< 30s).
@@ -413,6 +421,10 @@ def delete_placement_group(
413421
operation.result() # Wait for operation to complete
414422
except google.api_core.exceptions.NotFound:
415423
logger.debug("Placement group %s not found", placement_group.name)
424+
except google.api_core.exceptions.BadRequest as e:
425+
if "is already being used by" in e.message:
426+
raise PlacementGroupInUseError()
427+
raise
416428

417429
def create_gateway(
418430
self,
@@ -797,57 +809,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
797809
return f"{name}-{gpu.name}-{gpu.memory_mib}"
798810

799811

800-
def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
801-
commands = get_shim_commands(
802-
authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU"
803-
)
804-
startup_script = " ".join([" && ".join(commands)])
805-
startup_script = "#! /bin/bash\n" + startup_script
806-
return startup_script
807-
808-
809-
def _is_tpu(instance_name: str) -> bool:
810-
parts = instance_name.split("-")
811-
if len(parts) == 2:
812-
version, cores = parts
813-
if version in TPU_VERSIONS and cores.isdigit():
814-
return True
815-
return False
816-
817-
818-
def _get_tpu_runtime_version(instance_name: str) -> str:
819-
tpu_version = _get_tpu_version(instance_name)
820-
if tpu_version == "v6e":
821-
return "v2-alpha-tpuv6e"
822-
elif tpu_version == "v5litepod":
823-
return "v2-alpha-tpuv5-lite"
824-
return "tpu-ubuntu2204-base"
825-
826-
827-
def _get_tpu_version(instance_name: str) -> str:
828-
return instance_name.split("-")[0]
829-
830-
831-
def _is_single_host_tpu(instance_name: str) -> bool:
832-
parts = instance_name.split("-")
833-
if len(parts) != 2:
834-
logger.info("Skipping unknown TPU: %s", instance_name)
835-
return False
836-
tpu_version, tensor_cores = parts
837-
try:
838-
tensor_cores = int(tensor_cores)
839-
except ValueError:
840-
logger.info("Skipping TPU due to invalid number of tensor cores: %s", tensor_cores)
841-
return False
842-
if tpu_version in ["v2", "v3", "v5p", "v5litepod", "v6e"]:
843-
return tensor_cores <= 8
844-
elif tpu_version == "v4":
845-
return False
846-
else:
847-
logger.info("Skipping unknown TPU: %s", instance_name)
848-
return False
849-
850-
851812
def _get_backend_specific_commands_tcpx() -> List[str]:
852813
return [
853814
"cos-extensions install gpu -- --version=latest",
@@ -916,6 +877,57 @@ def _get_volume_price(size: int) -> float:
916877
return size * 0.12
917878

918879

880+
def _get_tpu_startup_script(authorized_keys: List[str]) -> str:
    """Build the bash startup script passed to a TPU node.

    The shim commands are requested in privileged mode with the PJRT device
    set to ``"TPU"``, chained with ``&&`` so they run as a single command
    line, and placed after a bash shebang line.

    Args:
        authorized_keys: Keys forwarded as-is to ``get_shim_commands``.

    Returns:
        The script text: the shebang line followed by the chained commands.
    """
    shim_commands = get_shim_commands(
        authorized_keys=authorized_keys,
        is_privileged=True,
        pjrt_device="TPU",
    )
    # All commands live on one line; "&&" stops at the first failing command.
    chained_commands = " && ".join(shim_commands)
    return "#! /bin/bash\n" + chained_commands
887+
888+
889+
def _is_tpu(instance_name: str) -> bool:
890+
parts = instance_name.split("-")
891+
if len(parts) == 2:
892+
version, cores = parts
893+
if version in TPU_VERSIONS and cores.isdigit():
894+
return True
895+
return False
896+
897+
898+
def _get_tpu_runtime_version(instance_name: str) -> str:
    """Pick the TPU runtime version string for an instance name.

    The TPU version is taken from ``_get_tpu_version(instance_name)``.
    ``"v6e"`` and ``"v5litepod"`` have dedicated runtime versions; every
    other version falls back to ``"tpu-ubuntu2204-base"``.

    Args:
        instance_name: TPU instance name, e.g. ``"v6e-8"``.

    Returns:
        The runtime version identifier.
    """
    dedicated_runtimes = {
        "v6e": "v2-alpha-tpuv6e",
        "v5litepod": "v2-alpha-tpuv5-lite",
    }
    return dedicated_runtimes.get(_get_tpu_version(instance_name), "tpu-ubuntu2204-base")
905+
906+
907+
def _get_tpu_version(instance_name: str) -> str:
908+
return instance_name.split("-")[0]
909+
910+
911+
def _is_single_host_tpu(instance_name: str) -> bool:
912+
parts = instance_name.split("-")
913+
if len(parts) != 2:
914+
logger.info("Skipping unknown TPU: %s", instance_name)
915+
return False
916+
tpu_version, tensor_cores = parts
917+
try:
918+
tensor_cores = int(tensor_cores)
919+
except ValueError:
920+
logger.info("Skipping TPU due to invalid number of tensor cores: %s", tensor_cores)
921+
return False
922+
if tpu_version in ["v2", "v3", "v5p", "v5litepod", "v6e"]:
923+
return tensor_cores <= 8
924+
elif tpu_version == "v4":
925+
return False
926+
else:
927+
logger.info("Skipping unknown TPU: %s", instance_name)
928+
return False
929+
930+
919931
def _get_tpu_data_disks(
920932
project_id: str, volumes: Optional[List[Volume]]
921933
) -> List[tpu_v2.AttachedDisk]:

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,11 @@ def wait_for_operation(operation: Operation, verbose_name: str = "operation", ti
434434

435435
def full_resource_name_to_name(full_resource_name: str) -> str:
    """Return the last ``/``-separated segment of a full resource name.

    A name without any ``/`` is returned unchanged; a name ending in ``/``
    yields an empty string.
    """
    _, _, short_name = full_resource_name.rpartition("/")
    return short_name
437+
438+
439+
def get_placement_policy_resource_name(
    project_id: str,
    region: str,
    placement_policy: str,
) -> str:
    """Compose the resource path of a regional placement (resource) policy.

    Args:
        project_id: Project segment of the path.
        region: Region segment of the path.
        placement_policy: Policy name used as the final path segment.

    Returns:
        ``projects/<project_id>/regions/<region>/resourcePolicies/<placement_policy>``.
    """
    template = "projects/{}/regions/{}/resourcePolicies/{}"
    return template.format(project_id, region, placement_policy)

0 commit comments

Comments
 (0)