diff --git a/sample_workloads/jax-tcpx-base-image/Dockerfile b/sample_workloads/jax-tcpx-base-image/Dockerfile new file mode 100644 index 00000000..7c56a406 --- /dev/null +++ b/sample_workloads/jax-tcpx-base-image/Dockerfile @@ -0,0 +1,47 @@ +FROM ubuntu:22.04 + +ARG VERSION=0.4.21 + +# Install python and pip +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip && \ + rm -rf /var/lib/apt/lists/* + + +RUN pip install --no-cache-dir --upgrade "jax[cuda12_pip]==${VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +WORKDIR /workspace/ + +ADD entrypoint.sh /entrypoint.sh + +ENTRYPOINT ["bash", "/entrypoint.sh"] + +# Environment variables required by TCPX +ENV NCCL_NVLS_ENABLE=0 +ENV NCCL_CROSS_NIC=0 +ENV NCCL_ALGO=Ring +ENV NCCL_PROTO=Simple +ENV NCCL_DEBUG=INFO +ENV NCCL_NET_GDR_LEVEL=PIX +ENV NCCL_P2P_PXN_LEVEL=0 +ENV NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION +ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/tcpx/lib64" +ENV NCCL_GPUDIRECTTCPX_FORCE_ACK=0 +ENV NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000 +ENV NCCL_DYNAMIC_CHUNK_SIZE=524288 +ENV NCCL_P2P_NET_CHUNKSIZE=524288 +ENV NCCL_P2P_PCI_CHUNKSIZE=524288 +ENV NCCL_P2P_NVL_CHUNKSIZE=1048576 +ENV NCCL_NSOCKS_PERTHREAD=4 +ENV NCCL_SOCKET_NTHREADS=1 +ENV NCCL_MAX_NCHANNELS=12 +ENV NCCL_MIN_NCHANNELS=12 +ENV NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000 +ENV NCCL_SOCKET_IFNAME=eth0 +ENV NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177" +ENV NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191" +ENV NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4 +ENV NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0 +# Might require adjusting based on where the TCPX socket is mounted +ENV NCCL_GPUDIRECTTCPX_UNIX_CLIENT_PREFIX=/run/tcpx diff --git a/sample_workloads/jax-tcpx-base-image/README.md b/sample_workloads/jax-tcpx-base-image/README.md new file mode 100644 index 00000000..70bdef20 --- /dev/null +++ b/sample_workloads/jax-tcpx-base-image/README.md @@ -0,0 +1,14 @@ +# Base image with JAX and TCPX config + +A base image to use with Jax and optimized TCPX config + +Image location: +``` +us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:0.4.21 +``` + +## Pushing new image +``` +gcloud builds submit --config=cloudbuild.yaml \ + --substitutions=_VERSION=0.4.21 --project gce-ai-infra +``` \ No newline at end of file diff --git a/sample_workloads/jax-tcpx-base-image/cloudbuild.yaml b/sample_workloads/jax-tcpx-base-image/cloudbuild.yaml new file mode 100644 index 00000000..edcd6959 --- /dev/null +++ b/sample_workloads/jax-tcpx-base-image/cloudbuild.yaml @@ -0,0 +1,6 @@ +# Build and push image to Artifact Registry +steps: +- name: 'gcr.io/cloud-builders/docker' + args: [ 'build', '--build-arg', 'VERSION=${_VERSION}', '-t', 'us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:${_VERSION}', '.' ] +images: +- 'us-docker.pkg.dev/$PROJECT_ID/jax-gpu/base-tcpx:${_VERSION}' \ No newline at end of file diff --git a/sample_workloads/jax-tcpx-base-image/entrypoint.sh b/sample_workloads/jax-tcpx-base-image/entrypoint.sh new file mode 100644 index 00000000..3b4686d4 --- /dev/null +++ b/sample_workloads/jax-tcpx-base-image/entrypoint.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -x + +function on_script_completion { + # semaphore to cleanly exit hardware utilization monitor + touch /run/tcpx/workload_terminated +} +trap on_script_completion EXIT + +exec "$@"