Skip to content

Commit 2e3da2c

Browse files
authored
Fix API quota hitting when provisioning many A3 instances (#2610)
1 parent f00aeae commit 2e3da2c

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import concurrent.futures
22
import json
3+
import threading
34
from collections import defaultdict
45
from typing import Callable, Dict, List, Literal, Optional, Tuple
56

67
import google.api_core.exceptions
78
import google.cloud.compute_v1 as compute_v1
9+
from cachetools import TTLCache, cachedmethod
810
from google.cloud import tpu_v2
911
from gpuhunt import KNOWN_TPUS
1012

@@ -98,6 +100,8 @@ def __init__(self, config: GCPConfig):
98100
self.resource_policies_client = compute_v1.ResourcePoliciesClient(
99101
credentials=self.credentials
100102
)
103+
self._extra_subnets_cache_lock = threading.Lock()
104+
self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60)
101105

102106
def get_offers(
103107
self, requirements: Optional[Requirements] = None
@@ -193,9 +197,7 @@ def create_instance(
193197
config=self.config,
194198
region=instance_offer.region,
195199
)
196-
extra_subnets = _get_extra_subnets(
197-
subnetworks_client=self.subnetworks_client,
198-
config=self.config,
200+
extra_subnets = self._get_extra_subnets(
199201
region=instance_offer.region,
200202
instance_type_name=instance_offer.instance.name,
201203
)
@@ -769,6 +771,38 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
769771
instance_id,
770772
)
771773

774+
@cachedmethod(
    cache=lambda self: self._extra_subnets_cache,
    lock=lambda self: self._extra_subnets_cache_lock,
)
def _get_extra_subnets(
    self,
    region: str,
    instance_type_name: str,
) -> List[Tuple[str, str]]:
    """Resolve the extra VPC subnets to attach to an A3 instance.

    Results are memoized through ``cachedmethod`` in
    ``self._extra_subnets_cache`` (guarded by
    ``self._extra_subnets_cache_lock``), so repeated calls with the same
    arguments do not repeat the subnet lookups.

    Args:
        region: Region passed to the subnet lookup.
        instance_type_name: Instance type name; only the A3 types listed
            below get extra subnets.

    Returns:
        A list of ``(vpc_resource_name, subnet)`` pairs, one per configured
        extra VPC, truncated to the per-type limit. Empty when no extra
        VPCs are configured or the instance type is not listed.
    """
    vpc_names = self.config.extra_vpcs
    if vpc_names is None:
        return []

    # How many of the configured extra VPCs each instance type uses.
    subnets_per_type = {
        "a3-megagpu-8g": 8,
        "a3-edgegpu-8g": 4,
        "a3-highgpu-8g": 4,
    }
    limit = subnets_per_type.get(instance_type_name)
    if limit is None:
        return []

    # The VPCs live in the dedicated VPC project when one is configured,
    # otherwise in the main project.
    project_id = self.config.vpc_project_id or self.config.project_id

    pairs = []
    for name in vpc_names[:limit]:
        # NOTE(review): presumably raises when the VPC has no subnet in
        # `region` -- confirm against gcp_resources.
        subnet = gcp_resources.get_vpc_subnet_or_error(
            subnetworks_client=self.subnetworks_client,
            vpc_project_id=project_id,
            vpc_name=name,
            region=region,
        )
        resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
            project_id=project_id,
            vpc_name=name,
        )
        pairs.append((resource_name, subnet))
    return pairs
805+
772806

773807
def _supported_instances_and_zones(
774808
regions: List[str],
@@ -843,36 +877,6 @@ def _get_vpc_subnet(
843877
)
844878

845879

846-
def _get_extra_subnets(
    subnetworks_client: compute_v1.SubnetworksClient,
    config: GCPConfig,
    region: str,
    instance_type_name: str,
) -> List[Tuple[str, str]]:
    """Resolve the extra VPC subnets to attach to an A3 instance.

    Args:
        subnetworks_client: Client handed to the subnet lookup helper.
        config: Backend config; ``extra_vpcs``, ``vpc_project_id`` and
            ``project_id`` are read from it.
        region: Region passed to the subnet lookup.
        instance_type_name: Instance type name; only the A3 types handled
            below get extra subnets.

    Returns:
        A list of ``(vpc_resource_name, subnet)`` pairs, one per configured
        extra VPC, truncated to the per-type limit. Empty when no extra
        VPCs are configured or the instance type is not handled.
    """
    vpc_names = config.extra_vpcs
    if vpc_names is None:
        return []

    # Pick how many of the configured extra VPCs this instance type uses.
    four_subnet_types = ["a3-edgegpu-8g", "a3-highgpu-8g"]
    if instance_type_name == "a3-megagpu-8g":
        vpc_limit = 8
    elif instance_type_name in four_subnet_types:
        vpc_limit = 4
    else:
        return []

    # Use the dedicated VPC project when configured, else the main project.
    owner_project = config.vpc_project_id or config.project_id

    result = []
    for current_vpc in vpc_names[:vpc_limit]:
        # NOTE(review): presumably raises when the VPC has no subnet in
        # `region` -- confirm against gcp_resources.
        found_subnet = gcp_resources.get_vpc_subnet_or_error(
            subnetworks_client=subnetworks_client,
            vpc_project_id=owner_project,
            vpc_name=current_vpc,
            region=region,
        )
        full_vpc_name = gcp_resources.vpc_name_to_vpc_resource_name(
            project_id=owner_project,
            vpc_name=current_vpc,
        )
        result.append((full_vpc_name, found_subnet))
    return result
874-
875-
876880
def _get_image_id(instance_type_name: str, cuda: bool) -> str:
877881
if instance_type_name == "a3-megagpu-8g":
878882
image_name = "dstack-a3mega-5"

0 commit comments

Comments
 (0)