Skip to content

Commit d6e5569

Browse files
committed
WIP: support GCP TCPXO
1 parent 8905c8c commit d6e5569

File tree

3 files changed

+112
-27
lines changed

3 files changed

+112
-27
lines changed

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939

4040
logger = get_logger(__name__)
4141

42-
DSTACK_WORKING_DIR = "/root/.dstack"
42+
DSTACK_WORKING_DIR = "/etc/.dstack"
4343
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
44-
DSTACK_SHIM_BINARY_PATH = f"/usr/local/bin/{DSTACK_SHIM_BINARY_NAME}"
44+
DSTACK_SHIM_BINARY_PATH = f"/etc/{DSTACK_SHIM_BINARY_NAME}"
4545
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46-
DSTACK_RUNNER_BINARY_PATH = f"/usr/local/bin/{DSTACK_RUNNER_BINARY_NAME}"
46+
DSTACK_RUNNER_BINARY_PATH = f"/etc/{DSTACK_RUNNER_BINARY_NAME}"
4747

4848

4949
class Compute(ABC):
@@ -525,7 +525,7 @@ def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List
525525
pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""
526526

527527
return [
528-
f"nohup dstack-shim {privileged_flag} {pjrt_device_env} &",
528+
f"nohup {DSTACK_SHIM_BINARY_PATH} {privileged_flag} {pjrt_device_env} &",
529529
]
530530

531531

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,10 @@ def create_instance(
269269
gpus=instance_offer.instance.resources.gpus,
270270
),
271271
spot=instance_offer.instance.resources.spot,
272-
user_data=get_user_data(authorized_keys),
272+
user_data=get_user_data(
273+
authorized_keys,
274+
backend_specific_commands=_get_backend_specific_commands_tcpxo(),
275+
),
273276
authorized_keys=authorized_keys,
274277
labels=labels,
275278
tags=[gcp_resources.DSTACK_INSTANCE_TAG],
@@ -805,6 +808,68 @@ def _is_single_host_tpu(instance_name: str) -> bool:
805808
return False
806809

807810

811+
def _get_backend_specific_commands_tcpx() -> List[str]:
812+
return [
813+
"cos-extensions install gpu -- --version=latest",
814+
"sudo mount --bind /var/lib/nvidia /var/lib/nvidia",
815+
"sudo mount -o remount,exec /var/lib/nvidia",
816+
(
817+
"docker run "
818+
"--detach "
819+
"--pull=always"
820+
"--name receive-datapath-manager "
821+
"--privileged "
822+
"--cap-add=NET_ADMIN --network=host "
823+
"--volume /var/lib/nvidia/lib64:/usr/local/nvidia/lib64 "
824+
"--device /dev/nvidia0:/dev/nvidia0 --device /dev/nvidia1:/dev/nvidia1 "
825+
"--device /dev/nvidia2:/dev/nvidia2 --device /dev/nvidia3:/dev/nvidia3 "
826+
"--device /dev/nvidia4:/dev/nvidia4 --device /dev/nvidia5:/dev/nvidia5 "
827+
"--device /dev/nvidia6:/dev/nvidia6 --device /dev/nvidia7:/dev/nvidia7 "
828+
"--device /dev/nvidia-uvm:/dev/nvidia-uvm --device /dev/nvidiactl:/dev/nvidiactl "
829+
"--env LD_LIBRARY_PATH=/usr/local/nvidia/lib64 "
830+
"--volume /run/tcpx:/run/tcpx "
831+
"--entrypoint /tcpgpudmarxd/build/app/tcpgpudmarxd "
832+
"us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd "
833+
'--gpu_nic_preset a3vm --gpu_shmem_type fd --uds_path "/run/tcpx" --setup_param "--verbose 128 2 0"'
834+
),
835+
"sudo iptables -I INPUT -p tcp -m tcp -j ACCEPT",
836+
"docker run --rm -v /var/lib:/var/lib us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/nccl-plugin-gpudirecttcpx install --install-nccl",
837+
"sudo mount --bind /var/lib/tcpx /var/lib/tcpx",
838+
"sudo mount -o remount,exec /var/lib/tcpx",
839+
]
840+
841+
842+
def _get_backend_specific_commands_tcpxo() -> List[str]:
843+
return [
844+
"modprobe import-helper",
845+
"gcloud -q auth configure-docker us-docker.pkg.dev",
846+
# Install the nccl, nccl-net lib into /var/lib/tcpxo/lib64/.
847+
(
848+
"docker run --rm --name nccl-installer "
849+
"--network=host "
850+
"--volume /var/lib:/var/lib "
851+
"us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/nccl-plugin-gpudirecttcpx-dev:v1.0.8-1 "
852+
"install --install-nccl"
853+
),
854+
# Start FasTrak receive-datapath-manager
855+
(
856+
"docker run "
857+
"--detach "
858+
"--pull=always "
859+
"--name receive-datapath-manager "
860+
"--cap-add=NET_ADMIN "
861+
"--network=host "
862+
"--privileged "
863+
"--gpus all "
864+
"--volume /usr/lib32:/usr/local/nvidia/lib64 "
865+
"--volume /dev/dmabuf_import_helper:/dev/dmabuf_import_helper "
866+
"--env LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu "
867+
"us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.14 "
868+
"--num_hops=2 --num_nics=8 --uid= --alsologtostderr"
869+
),
870+
]
871+
872+
808873
def _get_volume_price(size: int) -> float:
809874
# https://cloud.google.com/compute/disks-image-pricing#persistentdisk
810875
# The price is different in different regions. Take max across supported regions.

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from google.api_core.operation import Operation
99
from google.cloud import tpu_v2
1010

11-
import dstack.version as version
1211
from dstack._internal.core.errors import BackendError, ComputeError
1312
from dstack._internal.core.models.instances import Gpu
1413
from dstack._internal.utils.common import remove_prefix
@@ -119,24 +118,14 @@ def create_instance_struct(
119118
subnetwork: Optional[str] = None,
120119
allocate_public_ip: bool = True,
121120
) -> compute_v1.Instance:
122-
network_interface = compute_v1.NetworkInterface()
123-
network_interface.network = network
124-
if subnetwork is not None:
125-
network_interface.subnetwork = subnetwork
126-
127-
if allocate_public_ip:
128-
access = compute_v1.AccessConfig()
129-
access.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name
130-
access.name = "External NAT"
131-
access.network_tier = access.NetworkTier.PREMIUM.name
132-
network_interface.access_configs = [access]
133-
else:
134-
network_interface.access_configs = []
135-
136121
instance = compute_v1.Instance()
137-
instance.network_interfaces = [network_interface]
138122
instance.name = instance_name
139123
instance.machine_type = f"zones/{zone}/machineTypes/{machine_type}"
124+
instance.network_interfaces = _get_network_interfaces(
125+
network=network,
126+
subnetwork=subnetwork,
127+
allocate_public_ip=allocate_public_ip,
128+
)
140129

141130
disk = compute_v1.AttachedDisk()
142131
disk.auto_delete = True
@@ -187,14 +176,45 @@ def create_instance_struct(
187176
return instance
188177

189178

190-
def get_image_id(cuda: bool) -> str:
191-
if not cuda:
192-
image_name = f"dstack-{version.base_image}"
179+
def _get_network_interfaces(
180+
network: str,
181+
subnetwork: Optional[str],
182+
allocate_public_ip: bool,
183+
) -> List[compute_v1.NetworkInterface]:
184+
network_interface = compute_v1.NetworkInterface()
185+
network_interface.network = network
186+
if subnetwork is not None:
187+
network_interface.subnetwork = subnetwork
188+
if allocate_public_ip:
189+
access = compute_v1.AccessConfig()
190+
access.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name
191+
access.name = "External NAT"
192+
access.network_tier = access.NetworkTier.PREMIUM.name
193+
network_interface.access_configs = [access]
193194
else:
194-
image_name = f"dstack-cuda-{version.base_image}"
195-
image_name = image_name.replace(".", "-")
195+
network_interface.access_configs = []
196+
197+
network_interfaces = [network_interface]
198+
for i in range(1, 9):
199+
network_interfaces.append(
200+
compute_v1.NetworkInterface(
201+
network=f"projects/dstack/global/networks/dstack-test-data-net-{i}",
202+
subnetwork=f"projects/dstack/regions/europe-west4/subnetworks/dstack-test-data-sub-{i}",
203+
)
204+
)
205+
return network_interfaces
196206

197-
return f"projects/dstack/global/images/{image_name}"
207+
208+
def get_image_id(cuda: bool) -> str:
209+
# if not cuda:
210+
# image_name = f"dstack-{version.base_image}"
211+
# else:
212+
# image_name = f"dstack-cuda-{version.base_image}"
213+
# image_name = image_name.replace(".", "-")
214+
215+
# return f"projects/dstack/global/images/{image_name}"
216+
# return "projects/cos-cloud/global/images/cos-105-17412-535-78" # TCPX
217+
return "projects/dstack/global/images/slurm-a3mega-20250327t101736z-cloudinit" # TCPXO
198218

199219

200220
def get_gateway_image_id() -> str:

0 commit comments

Comments
 (0)