Skip to content

Commit 8ad206d

Browse files
committed
Add extra_vpcs to GCP config
1 parent d5d37dc commit 8ad206d

File tree

5 files changed

+75
-22
lines changed

5 files changed

+75
-22
lines changed

.github/workflows/gcp-a3mega-image.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ name: Build GCP A3 Mega VM image
22

33
on:
44
- workflow_dispatch
5-
- push
65

76
env:
87
PACKER_VERSION: "1.9.2"

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import concurrent.futures
22
import json
33
from collections import defaultdict
4-
from typing import Callable, Dict, List, Literal, Optional
4+
from typing import Callable, Dict, List, Literal, Optional, Tuple
55

66
import google.api_core.exceptions
77
import google.cloud.compute_v1 as compute_v1
@@ -192,6 +192,12 @@ def create_instance(
192192
config=self.config,
193193
region=instance_offer.region,
194194
)
195+
extra_subnets = _get_extra_subnets(
196+
subnetworks_client=self.subnetworks_client,
197+
config=self.config,
198+
region=instance_offer.region,
199+
instance_type_name=instance_offer.instance.name,
200+
)
195201
placement_policy = None
196202
if instance_config.placement_group_name is not None:
197203
placement_policy = gcp_resources.get_placement_policy_resource_name(
@@ -300,6 +306,7 @@ def create_instance(
300306
service_account=self.config.vm_service_account,
301307
network=self.config.vpc_resource_name,
302308
subnetwork=subnetwork,
309+
extra_subnetworks=extra_subnets,
303310
allocate_public_ip=allocate_public_ip,
304311
placement_policy=placement_policy,
305312
)
@@ -741,21 +748,6 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
741748
)
742749

743750

744-
def _get_vpc_subnet(
745-
subnetworks_client: compute_v1.SubnetworksClient,
746-
config: GCPConfig,
747-
region: str,
748-
) -> Optional[str]:
749-
if config.vpc_name is None:
750-
return None
751-
return gcp_resources.get_vpc_subnet_or_error(
752-
subnetworks_client=subnetworks_client,
753-
vpc_project_id=config.vpc_project_id or config.project_id,
754-
vpc_name=config.vpc_name,
755-
region=region,
756-
)
757-
758-
759751
def _supported_instances_and_zones(
760752
regions: List[str],
761753
) -> Optional[Callable[[InstanceOffer], bool]]:
@@ -814,6 +806,47 @@ def _unique_instance_name(instance: InstanceType) -> str:
814806
return f"{name}-{gpu.name}-{gpu.memory_mib}"
815807

816808

809+
def _get_vpc_subnet(
810+
subnetworks_client: compute_v1.SubnetworksClient,
811+
config: GCPConfig,
812+
region: str,
813+
) -> Optional[str]:
814+
if config.vpc_name is None:
815+
return None
816+
return gcp_resources.get_vpc_subnet_or_error(
817+
subnetworks_client=subnetworks_client,
818+
vpc_project_id=config.vpc_project_id or config.project_id,
819+
vpc_name=config.vpc_name,
820+
region=region,
821+
)
822+
823+
824+
def _get_extra_subnets(
825+
subnetworks_client: compute_v1.SubnetworksClient,
826+
config: GCPConfig,
827+
region: str,
828+
instance_type_name: str,
829+
) -> List[Tuple[str, str]]:
830+
if config.extra_vpcs is None:
831+
return []
832+
if instance_type_name != "a3-megagpu-8g":
833+
return []
834+
extra_subnets = []
835+
for vpc_name in config.extra_vpcs:
836+
subnet = gcp_resources.get_vpc_subnet_or_error(
837+
subnetworks_client=subnetworks_client,
838+
vpc_project_id=config.vpc_project_id or config.project_id,
839+
vpc_name=vpc_name,
840+
region=region,
841+
)
842+
vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
843+
project_id=config.vpc_project_id or config.project_id,
844+
vpc_name=vpc_name,
845+
)
846+
extra_subnets.append((vpc_resource_name, subnet))
847+
return extra_subnets[:8]
848+
849+
817850
def _get_image_id(instance_type_name: str, cuda: bool) -> str:
818851
if instance_type_name == "a3-megagpu-8g":
819852
image_name = "dstack-a3mega-2"

src/dstack/_internal/core/backends/gcp/configurator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,5 @@ def _check_config_vpc(
199199
)
200200
except BackendError as e:
201201
raise ServerClientError(e.args[0])
202+
# Not checking config.extra_vpc so that users are not required to configure subnets for all regions
203+
# but only for regions they intend to use. Validation will be done on provisioning.

src/dstack/_internal/core/backends/gcp/models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,19 @@ class GCPBackendConfig(CoreModel):
3333
regions: Annotated[
3434
Optional[List[str]], Field(description="The list of GCP regions. Omit to use all regions")
3535
] = None
36-
vpc_name: Annotated[Optional[str], Field(description="The name of a custom VPC")] = None
36+
vpc_name: Annotated[
37+
Optional[str],
38+
Field(description="The name of a custom VPC. If not specified, the default VPC is used"),
39+
] = None
40+
extra_vpcs: Annotated[
41+
Optional[List[str]],
42+
Field(
43+
description=(
44+
"The names of additional VPCs used for GPUDirect. Specify eight VPCs to maximize bandwidth."
45+
" Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets"
46+
)
47+
),
48+
] = None
3749
vpc_project_id: Annotated[
3850
Optional[str],
3951
Field(description="The shared VPC hosted project ID. Required for shared VPC only"),

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import concurrent.futures
22
import re
3-
from typing import Dict, List, Optional
3+
from typing import Dict, List, Optional, Tuple
44

55
import google.api_core.exceptions
66
import google.cloud.compute_v1 as compute_v1
@@ -116,6 +116,7 @@ def create_instance_struct(
116116
service_account: Optional[str] = None,
117117
network: str = "global/networks/default",
118118
subnetwork: Optional[str] = None,
119+
extra_subnetworks: Optional[List[Tuple[str, str]]] = None,
119120
allocate_public_ip: bool = True,
120121
placement_policy: Optional[str] = None,
121122
) -> compute_v1.Instance:
@@ -126,6 +127,7 @@ def create_instance_struct(
126127
network=network,
127128
subnetwork=subnetwork,
128129
allocate_public_ip=allocate_public_ip,
130+
extra_subnetworks=extra_subnetworks,
129131
)
130132

131133
disk = compute_v1.AttachedDisk()
@@ -184,6 +186,7 @@ def _get_network_interfaces(
184186
network: str,
185187
subnetwork: Optional[str],
186188
allocate_public_ip: bool,
189+
extra_subnetworks: Optional[List[Tuple[str, str]]],
187190
) -> List[compute_v1.NetworkInterface]:
188191
network_interface = compute_v1.NetworkInterface()
189192
network_interface.network = network
@@ -199,11 +202,11 @@ def _get_network_interfaces(
199202
network_interface.access_configs = []
200203

201204
network_interfaces = [network_interface]
202-
for i in range(1, 9):
205+
for network, subnetwork in extra_subnetworks or []:
203206
network_interfaces.append(
204207
compute_v1.NetworkInterface(
205-
network=f"projects/dstack/global/networks/dstack-test-data-net-{i}",
206-
subnetwork=f"projects/dstack/regions/europe-west4/subnetworks/dstack-test-data-sub-{i}",
208+
network=network,
209+
subnetwork=subnetwork,
207210
)
208211
)
209212
return network_interfaces
@@ -420,6 +423,10 @@ def full_resource_name_to_name(full_resource_name: str) -> str:
420423
return full_resource_name.split("/")[-1]
421424

422425

426+
def vpc_name_to_vpc_resource_name(project_id: str, vpc_name: str) -> str:
427+
return f"projects/{project_id}/global/networks/{vpc_name}"
428+
429+
423430
def get_placement_policy_resource_name(
424431
project_id: str,
425432
region: str,

0 commit comments

Comments
 (0)