|
1 | 1 | import concurrent.futures |
2 | 2 | import json |
3 | 3 | from collections import defaultdict |
4 | | -from typing import Callable, Dict, List, Literal, Optional |
| 4 | +from typing import Callable, Dict, List, Literal, Optional, Tuple |
5 | 5 |
|
6 | 6 | import google.api_core.exceptions |
7 | 7 | import google.cloud.compute_v1 as compute_v1 |
@@ -192,6 +192,12 @@ def create_instance( |
192 | 192 | config=self.config, |
193 | 193 | region=instance_offer.region, |
194 | 194 | ) |
| 195 | + extra_subnets = _get_extra_subnets( |
| 196 | + subnetworks_client=self.subnetworks_client, |
| 197 | + config=self.config, |
| 198 | + region=instance_offer.region, |
| 199 | + instance_type_name=instance_offer.instance.name, |
| 200 | + ) |
195 | 201 | placement_policy = None |
196 | 202 | if instance_config.placement_group_name is not None: |
197 | 203 | placement_policy = gcp_resources.get_placement_policy_resource_name( |
@@ -300,6 +306,7 @@ def create_instance( |
300 | 306 | service_account=self.config.vm_service_account, |
301 | 307 | network=self.config.vpc_resource_name, |
302 | 308 | subnetwork=subnetwork, |
| 309 | + extra_subnetworks=extra_subnets, |
303 | 310 | allocate_public_ip=allocate_public_ip, |
304 | 311 | placement_policy=placement_policy, |
305 | 312 | ) |
@@ -741,21 +748,6 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): |
741 | 748 | ) |
742 | 749 |
|
743 | 750 |
|
744 | | -def _get_vpc_subnet( |
745 | | - subnetworks_client: compute_v1.SubnetworksClient, |
746 | | - config: GCPConfig, |
747 | | - region: str, |
748 | | -) -> Optional[str]: |
749 | | - if config.vpc_name is None: |
750 | | - return None |
751 | | - return gcp_resources.get_vpc_subnet_or_error( |
752 | | - subnetworks_client=subnetworks_client, |
753 | | - vpc_project_id=config.vpc_project_id or config.project_id, |
754 | | - vpc_name=config.vpc_name, |
755 | | - region=region, |
756 | | - ) |
757 | | - |
758 | | - |
759 | 751 | def _supported_instances_and_zones( |
760 | 752 | regions: List[str], |
761 | 753 | ) -> Optional[Callable[[InstanceOffer], bool]]: |
@@ -814,6 +806,47 @@ def _unique_instance_name(instance: InstanceType) -> str: |
814 | 806 | return f"{name}-{gpu.name}-{gpu.memory_mib}" |
815 | 807 |
|
816 | 808 |
|
| 809 | +def _get_vpc_subnet( |
| 810 | + subnetworks_client: compute_v1.SubnetworksClient, |
| 811 | + config: GCPConfig, |
| 812 | + region: str, |
| 813 | +) -> Optional[str]: |
| 814 | + if config.vpc_name is None: |
| 815 | + return None |
| 816 | + return gcp_resources.get_vpc_subnet_or_error( |
| 817 | + subnetworks_client=subnetworks_client, |
| 818 | + vpc_project_id=config.vpc_project_id or config.project_id, |
| 819 | + vpc_name=config.vpc_name, |
| 820 | + region=region, |
| 821 | + ) |
| 822 | + |
| 823 | + |
| 824 | +def _get_extra_subnets( |
| 825 | + subnetworks_client: compute_v1.SubnetworksClient, |
| 826 | + config: GCPConfig, |
| 827 | + region: str, |
| 828 | + instance_type_name: str, |
| 829 | +) -> List[Tuple[str, str]]: |
| 830 | + if config.extra_vpcs is None: |
| 831 | + return [] |
| 832 | + if instance_type_name != "a3-megagpu-8g": |
| 833 | + return [] |
| 834 | + extra_subnets = [] |
| 835 | + for vpc_name in config.extra_vpcs: |
| 836 | + subnet = gcp_resources.get_vpc_subnet_or_error( |
| 837 | + subnetworks_client=subnetworks_client, |
| 838 | + vpc_project_id=config.vpc_project_id or config.project_id, |
| 839 | + vpc_name=vpc_name, |
| 840 | + region=region, |
| 841 | + ) |
| 842 | + vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( |
| 843 | + project_id=config.vpc_project_id or config.project_id, |
| 844 | + vpc_name=vpc_name, |
| 845 | + ) |
| 846 | + extra_subnets.append((vpc_resource_name, subnet)) |
| 847 | + return extra_subnets[:8] |
| 848 | + |
| 849 | + |
817 | 850 | def _get_image_id(instance_type_name: str, cuda: bool) -> str: |
818 | 851 | if instance_type_name == "a3-megagpu-8g": |
819 | 852 | image_name = "dstack-a3mega-2" |
|
0 commit comments