Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions sample_workloads/jax-tcpx-base-image/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Base OS layer for the JAX + TCPX image.
FROM ubuntu:22.04

# JAX release to install; override with `--build-arg VERSION=<x.y.z>`.
ARG VERSION=0.4.21

# Python 3 and pip from the distro repositories. The apt package lists are
# removed in the same layer so they are not stored in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        python3 \
        python3-pip \
    && rm -rf /var/lib/apt/lists/*

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parambole not sure if this pip upgrade is really needed?


# Install the CUDA 12 variant of JAX at the release selected by the VERSION
# build argument. `-f` adds the JAX CUDA release page as an extra location for
# pip to look for packages.
RUN pip install --no-cache-dir --upgrade "jax[cuda12_pip]==${VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

WORKDIR /workspace/

# COPY rather than ADD: entrypoint.sh is a plain local file, so none of ADD's
# extra behaviour (archive extraction, remote fetch) is wanted here.
COPY entrypoint.sh /entrypoint.sh

# Exec form: the container command is handed to the script as its arguments.
ENTRYPOINT ["bash", "/entrypoint.sh"]

# Environment variables required by TCPX
ENV NCCL_NVLS_ENABLE=0
ENV NCCL_CROSS_NIC=0
ENV NCCL_ALGO=Ring
ENV NCCL_PROTO=Simple
ENV NCCL_DEBUG=INFO
ENV NCCL_NET_GDR_LEVEL=PIX
ENV NCCL_P2P_PXN_LEVEL=0
ENV NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
# Append the TCPX library directory. The ${VAR:+...} form only emits the
# "<existing>:" prefix when LD_LIBRARY_PATH is already set; a plain
# "${LD_LIBRARY_PATH}:..." would start with an empty entry when it is unset,
# and the dynamic loader treats an empty entry as the current directory.
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/tcpx/lib64"
ENV NCCL_GPUDIRECTTCPX_FORCE_ACK=0
ENV NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
ENV NCCL_DYNAMIC_CHUNK_SIZE=524288
ENV NCCL_P2P_NET_CHUNKSIZE=524288
ENV NCCL_P2P_PCI_CHUNKSIZE=524288
ENV NCCL_P2P_NVL_CHUNKSIZE=1048576
ENV NCCL_NSOCKS_PERTHREAD=4
ENV NCCL_SOCKET_NTHREADS=1
ENV NCCL_MAX_NCHANNELS=12
ENV NCCL_MIN_NCHANNELS=12
ENV NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
ENV NCCL_SOCKET_IFNAME=eth0
# Per-interface TX/RX bindings. The RX ranges are kept disjoint from the TX
# ranges: eth1/eth2 RX previously read 124-139, overlapping the TX range
# 112-125, while every other pair is adjacent and non-overlapping.
# NOTE(review): 126-139 follows the pattern of the other ranges; confirm
# against the GPUDirect-TCPX documentation for the target machine type.
ENV NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
ENV NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,126-139;eth2:22-35,126-139;eth3:74-87,178-191;eth4:74-87,178-191"
ENV NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
ENV NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
# Might require adjusting based on where the TCPX socket is mounted
ENV NCCL_GPUDIRECTTCPX_UNIX_CLIENT_PREFIX=/run/tcpx
14 changes: 14 additions & 0 deletions sample_workloads/jax-tcpx-base-image/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Base image with JAX and TCPX config

A base image with JAX preinstalled and an optimized TCPX configuration.

Image location:
```
us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:0.4.21
```

## Pushing a new image
```
gcloud builds submit --config=cloudbuild.yaml \
--substitutions=_VERSION=0.4.21 --project gce-ai-infra
```
6 changes: 6 additions & 0 deletions sample_workloads/jax-tcpx-base-image/cloudbuild.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Cloud Build config: build the image with Docker, then push it to Artifact Registry.
# The _VERSION substitution is used both as the image tag and as the
# Dockerfile's VERSION build argument.
steps:
  - name: 'gcr.io/cloud-builders/docker'
    args:
      - 'build'
      - '--build-arg'
      - 'VERSION=${_VERSION}'
      - '-t'
      - 'us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:${_VERSION}'
      - '.'
# Images listed here are pushed once all steps have succeeded.
images:
  - 'us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:${_VERSION}'
11 changes: 11 additions & 0 deletions sample_workloads/jax-tcpx-base-image/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
#
# Container entrypoint: runs the command given as arguments and, when the
# script exits, creates /run/tcpx/workload_terminated.
#
# The workload is started as a child process instead of via `exec`: `exec`
# replaces this shell, so the EXIT trap below would never run and the
# semaphore file would never be created.

set -x

function on_script_completion {
  # semaphore to cleanly exit hardware utilization monitor
  touch /run/tcpx/workload_terminated
}
trap on_script_completion EXIT

function forward_stop {
  # Relay a stop request to the workload once it has been started. SIGTERM is
  # always sent because background children of a non-interactive shell ignore
  # SIGINT.
  if [[ -n "${workload_pid:-}" ]]; then
    kill -TERM "$workload_pid" 2>/dev/null
  fi
}
trap forward_stop TERM INT

# No command given: nothing to run. Exit successfully (the EXIT trap still fires).
if (( $# == 0 )); then
  exit 0
fi

# Run in the background and `wait`, so trapped signals are handled right away
# rather than after the workload finishes. The workload's stdin is /dev/null,
# as for any background command when job control is off.
"$@" &
workload_pid=$!

wait "$workload_pid"
exit_code=$?

# A trapped signal makes `wait` return early; keep waiting until the workload
# has really finished so its own exit status is the one reported.
while kill -0 "$workload_pid" 2>/dev/null; do
  wait "$workload_pid"
  exit_code=$?
done

exit "$exit_code"