|
31 | 31 | ComputeError, |
32 | 32 | ComputeResourceNotFoundError, |
33 | 33 | NoCapacityError, |
| 34 | + PlacementGroupInUseError, |
34 | 35 | ProvisioningError, |
35 | 36 | ) |
36 | 37 | from dstack._internal.core.models.backends.base import BackendType |
@@ -189,6 +190,13 @@ def create_instance( |
189 | 190 | config=self.config, |
190 | 191 | region=instance_offer.region, |
191 | 192 | ) |
| 193 | + placement_policy = None |
| 194 | + if instance_config.placement_group_name is not None: |
| 195 | + placement_policy = gcp_resources.get_placement_policy_resource_name( |
| 196 | + project_id=self.config.project_id, |
| 197 | + region=instance_offer.region, |
| 198 | + placement_policy=instance_config.placement_group_name, |
| 199 | + ) |
192 | 200 | labels = { |
193 | 201 | "owner": "dstack", |
194 | 202 | "dstack_project": instance_config.project_name.lower(), |
@@ -288,7 +296,7 @@ def create_instance( |
288 | 296 | network=self.config.vpc_resource_name, |
289 | 297 | subnetwork=subnetwork, |
290 | 298 | allocate_public_ip=allocate_public_ip, |
291 | | - placement_policy=f"projects/{self.config.project_id}/regions/{instance_offer.region}/resourcePolicies/{instance_config.placement_group_name}", |
| 299 | + placement_policy=placement_policy, |
292 | 300 | ) |
293 | 301 | try: |
294 | 302 | # GCP needs some time to return an error in case of no capacity (< 30s). |
@@ -413,6 +421,10 @@ def delete_placement_group( |
413 | 421 | operation.result() # Wait for operation to complete |
414 | 422 | except google.api_core.exceptions.NotFound: |
415 | 423 | logger.debug("Placement group %s not found", placement_group.name) |
| 424 | + except google.api_core.exceptions.BadRequest as e: |
| 425 | + if "is already being used by" in e.message: |
| 426 | + raise PlacementGroupInUseError() |
| 427 | + raise |
416 | 428 |
|
417 | 429 | def create_gateway( |
418 | 430 | self, |
@@ -797,57 +809,6 @@ def _unique_instance_name(instance: InstanceType) -> str: |
797 | 809 | return f"{name}-{gpu.name}-{gpu.memory_mib}" |
798 | 810 |
|
799 | 811 |
|
800 | | -def _get_tpu_startup_script(authorized_keys: List[str]) -> str: |
801 | | - commands = get_shim_commands( |
802 | | - authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU" |
803 | | - ) |
804 | | - startup_script = " ".join([" && ".join(commands)]) |
805 | | - startup_script = "#! /bin/bash\n" + startup_script |
806 | | - return startup_script |
807 | | - |
808 | | - |
809 | | -def _is_tpu(instance_name: str) -> bool: |
810 | | - parts = instance_name.split("-") |
811 | | - if len(parts) == 2: |
812 | | - version, cores = parts |
813 | | - if version in TPU_VERSIONS and cores.isdigit(): |
814 | | - return True |
815 | | - return False |
816 | | - |
817 | | - |
818 | | -def _get_tpu_runtime_version(instance_name: str) -> str: |
819 | | - tpu_version = _get_tpu_version(instance_name) |
820 | | - if tpu_version == "v6e": |
821 | | - return "v2-alpha-tpuv6e" |
822 | | - elif tpu_version == "v5litepod": |
823 | | - return "v2-alpha-tpuv5-lite" |
824 | | - return "tpu-ubuntu2204-base" |
825 | | - |
826 | | - |
827 | | -def _get_tpu_version(instance_name: str) -> str: |
828 | | - return instance_name.split("-")[0] |
829 | | - |
830 | | - |
831 | | -def _is_single_host_tpu(instance_name: str) -> bool: |
832 | | - parts = instance_name.split("-") |
833 | | - if len(parts) != 2: |
834 | | - logger.info("Skipping unknown TPU: %s", instance_name) |
835 | | - return False |
836 | | - tpu_version, tensor_cores = parts |
837 | | - try: |
838 | | - tensor_cores = int(tensor_cores) |
839 | | - except ValueError: |
840 | | - logger.info("Skipping TPU due to invalid number of tensor cores: %s", tensor_cores) |
841 | | - return False |
842 | | - if tpu_version in ["v2", "v3", "v5p", "v5litepod", "v6e"]: |
843 | | - return tensor_cores <= 8 |
844 | | - elif tpu_version == "v4": |
845 | | - return False |
846 | | - else: |
847 | | - logger.info("Skipping unknown TPU: %s", instance_name) |
848 | | - return False |
849 | | - |
850 | | - |
851 | 812 | def _get_backend_specific_commands_tcpx() -> List[str]: |
852 | 813 | return [ |
853 | 814 | "cos-extensions install gpu -- --version=latest", |
@@ -916,6 +877,57 @@ def _get_volume_price(size: int) -> float: |
916 | 877 | return size * 0.12 |
917 | 878 |
|
918 | 879 |
|
| 880 | +def _get_tpu_startup_script(authorized_keys: List[str]) -> str: |
| 881 | + commands = get_shim_commands( |
| 882 | + authorized_keys=authorized_keys, is_privileged=True, pjrt_device="TPU" |
| 883 | + ) |
| 884 | + startup_script = " ".join([" && ".join(commands)]) |
| 885 | + startup_script = "#! /bin/bash\n" + startup_script |
| 886 | + return startup_script |
| 887 | + |
| 888 | + |
| 889 | +def _is_tpu(instance_name: str) -> bool: |
| 890 | + parts = instance_name.split("-") |
| 891 | + if len(parts) == 2: |
| 892 | + version, cores = parts |
| 893 | + if version in TPU_VERSIONS and cores.isdigit(): |
| 894 | + return True |
| 895 | + return False |
| 896 | + |
| 897 | + |
| 898 | +def _get_tpu_runtime_version(instance_name: str) -> str: |
| 899 | + tpu_version = _get_tpu_version(instance_name) |
| 900 | + if tpu_version == "v6e": |
| 901 | + return "v2-alpha-tpuv6e" |
| 902 | + elif tpu_version == "v5litepod": |
| 903 | + return "v2-alpha-tpuv5-lite" |
| 904 | + return "tpu-ubuntu2204-base" |
| 905 | + |
| 906 | + |
| 907 | +def _get_tpu_version(instance_name: str) -> str: |
| 908 | + return instance_name.split("-")[0] |
| 909 | + |
| 910 | + |
| 911 | +def _is_single_host_tpu(instance_name: str) -> bool: |
| 912 | + parts = instance_name.split("-") |
| 913 | + if len(parts) != 2: |
| 914 | + logger.info("Skipping unknown TPU: %s", instance_name) |
| 915 | + return False |
| 916 | + tpu_version, tensor_cores = parts |
| 917 | + try: |
| 918 | + tensor_cores = int(tensor_cores) |
| 919 | + except ValueError: |
| 920 | + logger.info("Skipping TPU due to invalid number of tensor cores: %s", tensor_cores) |
| 921 | + return False |
| 922 | + if tpu_version in ["v2", "v3", "v5p", "v5litepod", "v6e"]: |
| 923 | + return tensor_cores <= 8 |
| 924 | + elif tpu_version == "v4": |
| 925 | + return False |
| 926 | + else: |
| 927 | + logger.info("Skipping unknown TPU: %s", instance_name) |
| 928 | + return False |
| 929 | + |
| 930 | + |
919 | 931 | def _get_tpu_data_disks( |
920 | 932 | project_id: str, volumes: Optional[List[Volume]] |
921 | 933 | ) -> List[tpu_v2.AttachedDisk]: |
|
0 commit comments