|
1 | 1 | import concurrent.futures |
2 | 2 | import json |
| 3 | +import threading |
3 | 4 | from collections import defaultdict |
4 | 5 | from typing import Callable, Dict, List, Literal, Optional, Tuple |
5 | 6 |
|
6 | 7 | import google.api_core.exceptions |
7 | 8 | import google.cloud.compute_v1 as compute_v1 |
| 9 | +from cachetools import TTLCache, cachedmethod |
8 | 10 | from google.cloud import tpu_v2 |
9 | 11 | from gpuhunt import KNOWN_TPUS |
10 | 12 |
|
@@ -98,6 +100,8 @@ def __init__(self, config: GCPConfig): |
98 | 100 | self.resource_policies_client = compute_v1.ResourcePoliciesClient( |
99 | 101 | credentials=self.credentials |
100 | 102 | ) |
| 103 | + self._extra_subnets_cache_lock = threading.Lock() |
| 104 | + self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60) |
101 | 105 |
|
102 | 106 | def get_offers( |
103 | 107 | self, requirements: Optional[Requirements] = None |
@@ -193,9 +197,7 @@ def create_instance( |
193 | 197 | config=self.config, |
194 | 198 | region=instance_offer.region, |
195 | 199 | ) |
196 | | - extra_subnets = _get_extra_subnets( |
197 | | - subnetworks_client=self.subnetworks_client, |
198 | | - config=self.config, |
| 200 | + extra_subnets = self._get_extra_subnets( |
199 | 201 | region=instance_offer.region, |
200 | 202 | instance_type_name=instance_offer.instance.name, |
201 | 203 | ) |
@@ -769,6 +771,38 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): |
769 | 771 | instance_id, |
770 | 772 | ) |
771 | 773 |
|
| 774 | + @cachedmethod( |
| 775 | + cache=lambda self: self._extra_subnets_cache, |
| 776 | + lock=lambda self: self._extra_subnets_cache_lock, |
| 777 | + ) |
| 778 | + def _get_extra_subnets( |
| 779 | + self, |
| 780 | + region: str, |
| 781 | + instance_type_name: str, |
| 782 | + ) -> List[Tuple[str, str]]: |
| 783 | + if self.config.extra_vpcs is None: |
| 784 | + return [] |
| 785 | + if instance_type_name == "a3-megagpu-8g": |
| 786 | + subnets_num = 8 |
| 787 | + elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: |
| 788 | + subnets_num = 4 |
| 789 | + else: |
| 790 | + return [] |
| 791 | + extra_subnets = [] |
| 792 | + for vpc_name in self.config.extra_vpcs[:subnets_num]: |
| 793 | + subnet = gcp_resources.get_vpc_subnet_or_error( |
| 794 | + subnetworks_client=self.subnetworks_client, |
| 795 | + vpc_project_id=self.config.vpc_project_id or self.config.project_id, |
| 796 | + vpc_name=vpc_name, |
| 797 | + region=region, |
| 798 | + ) |
| 799 | + vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( |
| 800 | + project_id=self.config.vpc_project_id or self.config.project_id, |
| 801 | + vpc_name=vpc_name, |
| 802 | + ) |
| 803 | + extra_subnets.append((vpc_resource_name, subnet)) |
| 804 | + return extra_subnets |
| 805 | + |
772 | 806 |
|
773 | 807 | def _supported_instances_and_zones( |
774 | 808 | regions: List[str], |
@@ -843,36 +877,6 @@ def _get_vpc_subnet( |
843 | 877 | ) |
844 | 878 |
|
845 | 879 |
|
846 | | -def _get_extra_subnets( |
847 | | - subnetworks_client: compute_v1.SubnetworksClient, |
848 | | - config: GCPConfig, |
849 | | - region: str, |
850 | | - instance_type_name: str, |
851 | | -) -> List[Tuple[str, str]]: |
852 | | - if config.extra_vpcs is None: |
853 | | - return [] |
854 | | - if instance_type_name == "a3-megagpu-8g": |
855 | | - subnets_num = 8 |
856 | | - elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: |
857 | | - subnets_num = 4 |
858 | | - else: |
859 | | - return [] |
860 | | - extra_subnets = [] |
861 | | - for vpc_name in config.extra_vpcs[:subnets_num]: |
862 | | - subnet = gcp_resources.get_vpc_subnet_or_error( |
863 | | - subnetworks_client=subnetworks_client, |
864 | | - vpc_project_id=config.vpc_project_id or config.project_id, |
865 | | - vpc_name=vpc_name, |
866 | | - region=region, |
867 | | - ) |
868 | | - vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( |
869 | | - project_id=config.vpc_project_id or config.project_id, |
870 | | - vpc_name=vpc_name, |
871 | | - ) |
872 | | - extra_subnets.append((vpc_resource_name, subnet)) |
873 | | - return extra_subnets |
874 | | - |
875 | | - |
876 | 880 | def _get_image_id(instance_type_name: str, cuda: bool) -> str: |
877 | 881 | if instance_type_name == "a3-megagpu-8g": |
878 | 882 | image_name = "dstack-a3mega-5" |
|
0 commit comments