Merged
65 commits
6db8716
Add k8s JAX-vLLM offloading example
aybchan Nov 24, 2025
b398510
Update gateway URL
aybchan Nov 24, 2025
e0a1b67
Add two-node manifest
aybchan Nov 24, 2025
485e2a3
Add 8:8 logs
aybchan Nov 24, 2025
fd9c38f
Add hard-coded 2x node jobset example
aybchan Nov 25, 2025
0000997
patch vLLM weight loader
yhtang Nov 25, 2025
82fe0ce
bump tunix version
yhtang Nov 25, 2025
a3f36ff
Add jax[k8s] extras to install
aybchan Nov 25, 2025
e069f1b
Organize deployment manifests
aybchan Nov 26, 2025
aa926cb
Set missing env. vars
aybchan Nov 26, 2025
5591ce4
address PR comments
yhtang Nov 26, 2025
d45fa3a
address PR comments
yhtang Nov 26, 2025
771f97d
Remove debug trace
aybchan Nov 26, 2025
db4861b
Add JAX-vLLM workflow
aybchan Nov 26, 2025
20802b8
Fix JobSet command
aybchan Nov 26, 2025
794ff86
Add xpk patch, update env file, patch composite action
aybchan Nov 27, 2025
bc9d877
Enable image pull secret set
aybchan Nov 27, 2025
b64763f
Set jobset dot env path
aybchan Nov 27, 2025
f8cd259
Refactor CI workflows
aybchan Nov 27, 2025
b75f09f
Fix workflow
aybchan Nov 27, 2025
ed9d8b0
Fix workflow
aybchan Nov 27, 2025
8c85fe7
Fix workflow
aybchan Nov 27, 2025
19391d6
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan Nov 27, 2025
7e5e1ce
Add build to pipeline
aybchan Nov 27, 2025
2ab8e10
Set test build
aybchan Nov 27, 2025
77d6c82
Disable arm64 build for now
aybchan Nov 27, 2025
78800df
Update deprecated variable
aybchan Nov 28, 2025
a081257
revert change related to a vllm tokenizer load failure that is no lon…
yhtang Nov 28, 2025
8b3440f
remove tunix version pin
yhtang Nov 28, 2025
4372984
revert JAX version bump
yhtang Nov 28, 2025
b8c7b9c
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan Nov 28, 2025
989cb24
Merge workflow definitions
aybchan Nov 28, 2025
e07922c
Set run 70B transfer
aybchan Nov 28, 2025
0cc6920
Set arm64 and amd64 build separately
aybchan Nov 28, 2025
bd49355
Fix artifact name collision
aybchan Nov 28, 2025
ffa6706
Update keyword name due to tunix 974da5
aybchan Nov 28, 2025
f111e2f
Make xpk composite action changes backwards compatible
aybchan Nov 28, 2025
2bd3206
Add working k8s GRPO recipe
aybchan Nov 28, 2025
18950d7
Update GRPO workflow
aybchan Nov 28, 2025
df32d3b
Fix workflow
aybchan Nov 28, 2025
bdaafd0
Fix inline command
aybchan Nov 28, 2025
aba0842
Remove debug logs
aybchan Nov 28, 2025
4f7c0ec
Set consume step output
aybchan Nov 28, 2025
a1f5a4b
Set workload name strictly lower case
aybchan Nov 28, 2025
9c4b2bd
Handle invalid jobset name
aybchan Nov 29, 2025
db46141
Remove unnecessary unbound variable
aybchan Nov 30, 2025
9959cde
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan Dec 18, 2025
8d211d7
Update trigger
aybchan Dec 18, 2025
0c73a51
Remove build-arm64 job
aybchan Dec 18, 2025
45e7c53
Update job dependencies
aybchan Dec 18, 2025
869d626
Fix unbound variable
aybchan Dec 18, 2025
e9cfd24
Refactor gke-xpk action to set environment variables from string
aybchan Dec 18, 2025
288d621
Remove environment variable setting in start up script
aybchan Dec 18, 2025
83511ee
Set multi-line env var delimiter
aybchan Dec 18, 2025
1dee288
Update image var reference
aybchan Dec 19, 2025
13ed7b5
Fix script
aybchan Dec 19, 2025
edc596b
Fix unbound variable reference
aybchan Dec 19, 2025
c35b73e
Handle nested env arg
aybchan Dec 19, 2025
b9b428f
Update variable value
aybchan Dec 19, 2025
eb917e2
Cleanup workflow definition, adapt nccl, maxtext envs
aybchan Dec 19, 2025
9f6a450
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan Dec 23, 2025
031882f
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan Jan 6, 2026
06e1111
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan Jan 13, 2026
4a41013
Remove test files
aybchan Jan 13, 2026
f919b29
Refactor env exporting due to xpk jobset duplication
aybchan Jan 14, 2026
35 changes: 23 additions & 12 deletions .github/actions/gke-xpk/action.yml
@@ -28,7 +28,7 @@ inputs:
required: false
type: string
MAIN_CONTAINER_NAME:
description: 'Name of the main container in an XPK JobSet (fixed)'
description: 'Name of the main container in an XPK JobSet (fixed in xpk)'
default: gpu-image
required: false
type: string
@@ -57,6 +57,11 @@ inputs:
required: false
default: 'nvidia-smi; free -h;'
type: string
ENVS:
description: 'Environment variables to pass to xpk for setting in JobSet (delimited by ;)'
required: false
default: ''
type: string
EXIT_COMMAND:
description: 'Command to set exit code'
required: false
@@ -124,6 +129,7 @@ runs:
- name: Set workload commands
shell: bash -x -e -u {0}
run: |
# install dependencies needed to export artifacts from the container to the GCS bucket
PRELUDE="
apt install -y ripgrep > /dev/null;
curl -LO https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz;
@@ -133,6 +139,7 @@ runs:

mkdir -p /usr/share/workload;
mkdir -p ${{ inputs.CONTAINER_OUTPUT_PATH }};"

# Work around GCP's deployment model that munges together three
# mostly unrelated things: (1) the host machine's CUDA driver/libs,
# (2) the version of NCCL installed on the host machine, and (3)
@@ -151,6 +158,7 @@ runs:
env;
"

# gsutil command to export logs from container's /opt/output to bucket
POSTLUDE="
./google-cloud-sdk/bin/gsutil cp -r ${{ inputs.CONTAINER_OUTPUT_PATH }}/ ${GCS_ARTIFACT_PATH}/node-0\$NODE_RANK;
${{ inputs.EXIT_COMMAND }}
@@ -163,9 +171,7 @@ runs:
POSTLUDE=$(echo ${POSTLUDE} | sed 's/\n/\ /g')
CMD=$(echo ${CMD} | sed 's/\n/\ /g')

echo "PRELUDE=${PRELUDE}" >> ${GITHUB_ENV}
echo "CMD=${CMD}" >> ${GITHUB_ENV}
echo "POSTLUDE=${POSTLUDE}" >> ${GITHUB_ENV}
echo "CMD=${PRELUDE} ${CMD} ${POSTLUDE}" >> ${GITHUB_ENV}

- name: Create workload on cluster with XPK
shell: bash -x -e -u {0}
@@ -194,16 +200,21 @@ runs:
}

if version_greater "${{ inputs.XPK_VERSION }}" "v0.10.0"; then
args+=(
--docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }}
--env="JAX_COORDINATOR_PORT=3389"
--env="JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):3389"
)
fi
args+=(
--docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }}
)

envs_flat=$(echo "${{ inputs.ENVS }}" | tr '\n' ' ')
IFS=';' read -ra env_vars <<< "${envs_flat}"
for env in "${env_vars[@]}"; do
env=$(echo "${env}" | xargs)
[[ -n "${env}" ]] && args+=(--env="${env}")
done
fi

python xpk.py workload create \
${args[@]} \
--command="${PRELUDE} ${CMD} ${POSTLUDE}"
"${args[@]}" \
--command="${CMD}"

- name: Wait for JobSet to unsuspend on cluster
shell: bash -u {0}
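The `ENVS` parsing added to the action above can be exercised in isolation. This is a hypothetical standalone sketch (the `ENVS` value here is invented for illustration; the parsing logic mirrors the action's): a multi-line, semicolon-delimited string becomes one `--env=...` flag per entry, with blank entries skipped.

```shell
#!/usr/bin/env bash
# Sketch of the gke-xpk action's ENVS parsing.
set -u

ENVS='
JAX_COORDINATOR_PORT=3389;
NHOSTS=2;
'

args=()
# Flatten newlines so entries may span lines, then split on ';'.
envs_flat=$(echo "${ENVS}" | tr '\n' ' ')
IFS=';' read -ra env_vars <<< "${envs_flat}"
for env in "${env_vars[@]}"; do
  env=$(echo "${env}" | xargs)              # trim surrounding whitespace
  [[ -n "${env}" ]] && args+=(--env="${env}")
done

printf '%s\n' "${args[@]}"
```

Each resulting element can then be passed straight to `xpk.py workload create` via `"${args[@]}"`.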
@@ -0,0 +1,8 @@
apiVersion: v1
kind: Secret
metadata:
name: hf-token-secret
namespace: default
type: Opaque
stringData:
token: {{ HF_TOKEN }}
9 changes: 8 additions & 1 deletion .github/gke-workflow/xpk/v0.13.0/tcpxo_decorator.patch
@@ -1,5 +1,5 @@
diff --git a/src/xpk/core/workload_decorators/tcpxo_decorator.py b/src/xpk/core/workload_decorators/tcpxo_decorator.py
index 3734f87..dc3b24a 100644
index 3734f87..4a35459 100644
--- a/src/xpk/core/workload_decorators/tcpxo_decorator.py
+++ b/src/xpk/core/workload_decorators/tcpxo_decorator.py
@@ -181,7 +181,9 @@ def update_gpu_containers(job_manifest):
@@ -13,3 +13,10 @@ index 3734f87..dc3b24a 100644
)
container['env'].append({
'name': 'NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY',
@@ -197,3 +199,6 @@ def update_gpu_containers(job_manifest):
container['volumeMounts'].append(
{'name': 'dshm', 'mountPath': '/dev/shm'}
)
+ container['env'].append(
+ {'name': 'HF_TOKEN', 'valueFrom': {'secretKeyRef': {'name': 'hf-token-secret', 'key': 'token'}}}
+ )
5 changes: 4 additions & 1 deletion .github/workflows/_test_maxtext_gke_xpk.yaml
@@ -50,8 +50,11 @@ jobs:
IMAGE: ${{ env.MAXTEXT_IMAGE }}
IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }}
WORKLOAD_NAME_PREFIX: ${{ env.WORKLOAD_NAME_PREFIX }}
COMMAND: |
ENVS: |
JAX_COORDINATOR_PORT=3389;
JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):\$(JAX_COORDINATOR_PORT);
console=/dev/stdout;
COMMAND: |
nsys-jax --capture-range=cudaProfilerApi
--capture-range-end=stop
-o /opt/output/profile.zip
19 changes: 10 additions & 9 deletions .github/workflows/_test_nccl_gke.yaml
@@ -96,16 +96,17 @@ jobs:
IMAGE: ${{ env.BASE_IMAGE }}
IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }}
WORKLOAD_NAME_PREFIX: ${{ steps.workload-name.outputs.WORKLOAD_PREFIX }}
ENVS: |
JAX_COORDINATOR_PORT=3389;
JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):\$(JAX_COORDINATOR_PORT);
NHOSTS=${{ env.NHOSTS }};
NCCL_LIB_DIR=/opt/nvida/nccl/lib;
SCRIPT_DIR=/scripts;
NCCL_MINBYTES=${{ env.NCCL_MINBYTES }};
NCCL_MAXBYTES=${{ env.NCCL_MAXBYTES }};
NCCL_STEPFACTOR=${{ env.NCCL_STEPFACTOR }};
NCCL_ITERS=${{ env.NCCL_ITERS }};
COMMAND: |
export NHOSTS=${{ env.NHOSTS }};
export NCCL_LIB_DIR=/opt/nvida/nccl/lib;
export SCRIPT_DIR=/scripts;

export NCCL_MINBYTES=${{ env.NCCL_MINBYTES }};
export NCCL_MAXBYTES=${{ env.NCCL_MAXBYTES }};
export NCCL_STEPFACTOR=${{ env.NCCL_STEPFACTOR }};
export NCCL_ITERS=${{ env.NCCL_ITERS }};

service ssh restart;
console=/dev/stdout;
declare -a hosts=('nccl-test-host-1' 'nccl-test-host-2');
102 changes: 102 additions & 0 deletions .github/workflows/jax-vllm-offloading-gke-grpo.yml
@@ -0,0 +1,102 @@
name: JAX-vLLM offloading GRPO (GKE, XPK)

on:
workflow_call:
inputs:
JAX_VLLM_OFFLOADING_IMAGE:
type: string
description: JAX-vLLM offloading image from ghcr.io/nvidia
default: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64
required: false

jobs:
jax-vllm-offloading-grpo-gke-xpk:
runs-on: gke-a3mega
strategy:
matrix:
model: ["meta-llama/Llama-3.1-8B-Instruct"]
env:
WORKLOAD_NAME_PREPREFIX: vllm-grpo
JAX_VLLM_OFFLOADING_IMAGE: ${{ inputs.JAX_VLLM_OFFLOADING_IMAGE }}

NUM_NODES: 2

steps:
- uses: actions/checkout@v4

- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: K8s GHCR store and delete token
id: store-token
uses: ./.github/actions/store-delete-k8s-ghcr

- name: Format workload name
id: workload-name
run: |
WORKLOAD_NAME_PREFIX="${WORKLOAD_NAME_PREPREFIX}-$(echo ${{ matrix.model }} | sed 's|.*/\(.*\)-[^-]*|\1|')"
WORKLOAD_NAME_PREFIX=$(echo ${WORKLOAD_NAME_PREFIX} | tr '.' '-')
echo "WORKLOAD_NAME_PREFIX=${WORKLOAD_NAME_PREFIX,,}" >> ${GITHUB_OUTPUT}

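The workload-name munging in the step above can be checked standalone. A sketch under the workflow's own values: strip the org prefix and the trailing variant segment with `sed`, replace dots (invalid in JobSet names), and force lower case.

```shell
#!/usr/bin/env bash
# Sketch of the "Format workload name" step for one matrix entry.
set -eu

WORKLOAD_NAME_PREPREFIX=vllm-grpo
model="meta-llama/Llama-3.1-8B-Instruct"

# "meta-llama/Llama-3.1-8B-Instruct" -> "Llama-3.1-8B"
WORKLOAD_NAME_PREFIX="${WORKLOAD_NAME_PREPREFIX}-$(echo "${model}" | sed 's|.*/\(.*\)-[^-]*|\1|')"
WORKLOAD_NAME_PREFIX=$(echo "${WORKLOAD_NAME_PREFIX}" | tr '.' '-')   # dots are invalid in names
echo "${WORKLOAD_NAME_PREFIX,,}"                                      # -> vllm-grpo-llama-3-1-8b
```

The `${VAR,,}` lower-casing requires bash 4+, which matches the `bash` shells the workflows already use.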
- name: Run XPK workload on cluster
uses: ./.github/actions/gke-xpk
with:
IMAGE: ${{ env.JAX_VLLM_OFFLOADING_IMAGE }}
IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }}
WORKLOAD_NAME_PREFIX: ${{ steps.workload-name.outputs.WORKLOAD_NAME_PREFIX }}
ENVS: |
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;
CUDA_DEVICE_ORDER=PCI_BUS_ID;
CUDA_DEVICE_MAX_CONNECTIONS=16;
VLLM_ENFORCE_EAGER=1;
VLLM_GPU_MEMORY_UTILIZATION=0.7;
VLLM_TENSOR_PARALLEL_SIZE=8;
VLLM_DISTRIBUTED_BACKEND=mp;
VLLM_ATTENTION_BACKEND=TRITON_ATTN;
VLLM_LOAD_FORMAT=dummy;
NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so;
NCCL_TUNER_PLUGIN=none;
MODEL_NAME=${{ matrix.model }};
NCCL_CUMEM_ENABLE=0;
NCCL_BUFFSIZE=16777216;
XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592;
TRANSFER_MODE=grouped;
USE_POLYMORPHIC_MESH=0;
JAX_COORDINATOR_PORT=3389;
JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):\$(JAX_COORDINATOR_PORT);
GATEWAY_PORT=50051;
GATEWAY_URL=\$(JOBSET_NAME):\$(GATEWAY_PORT);
OUTPUT_DIR=/opt/output;
LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64;

COMMAND: |
set -x;

pip install jax[k8s];
python -c 'import jax; jax.distributed.initialize(); print(jax.devices()); print(jax.local_devices()); assert jax.process_count() > 1; assert len(jax.devices()) > len(jax.local_devices());';

PIDS=();
if [ \${NODE_RANK} = 0 ]; then
echo Starting gateway;
cd /opt/jtbx/jax-inference-offloading;
python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
PIDS+=(\$!);

echo Starting rollout;
cd /opt/jtbx/jax-inference-offloading/examples;
python rollout.py 2>&1 | tee -a rollout.log &
PIDS+=(\$!);
else
export MODEL_PATH=\$(python download_model.py --hub=hf --model=\${MODEL_NAME} --ignore='*.pth');

echo Starting GRPO trainer;
python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log &
PIDS+=(\$!);
fi;

wait \${PIDS[@]};
EXIT_CODE=\$PIPESTATUS;
99 changes: 99 additions & 0 deletions .github/workflows/jax-vllm-offloading-gke-transfer.yml
@@ -0,0 +1,99 @@
name: JAX-vLLM offloading transfer (GKE, XPK)

on:
workflow_call:
inputs:
JAX_VLLM_OFFLOADING_IMAGE:
type: string
description: JAX-vLLM offloading image from ghcr.io/nvidia
default: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64
required: false

jobs:
jax-vllm-offloading-transfer-gke-xpk:
runs-on: gke-a3mega
strategy:
matrix:
model: ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct"]
env:
WORKLOAD_NAME_PREPREFIX: vllm-transf # due to 40 character workload name limit
JAX_VLLM_OFFLOADING_IMAGE: ${{ inputs.JAX_VLLM_OFFLOADING_IMAGE }}

NUM_NODES: 2

steps:
- uses: actions/checkout@v4

- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: K8s GHCR store and delete token
id: store-token
uses: ./.github/actions/store-delete-k8s-ghcr

- name: Format workload name
id: workload-name
run: |
WORKLOAD_NAME_PREFIX="${WORKLOAD_NAME_PREPREFIX}-$(echo ${{ matrix.model }} | sed 's|.*/\(.*\)-[^-]*|\1|')"
WORKLOAD_NAME_PREFIX=$(echo ${WORKLOAD_NAME_PREFIX} | tr '.' '-')
echo "WORKLOAD_NAME_PREFIX=${WORKLOAD_NAME_PREFIX,,}" >> ${GITHUB_OUTPUT}

- name: Run XPK workload on cluster
uses: ./.github/actions/gke-xpk
with:
IMAGE: ${{ env.JAX_VLLM_OFFLOADING_IMAGE }}
IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }}
WORKLOAD_NAME_PREFIX: ${{ steps.workload-name.outputs.WORKLOAD_NAME_PREFIX }}
ENVS: |
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;
CUDA_DEVICE_ORDER=PCI_BUS_ID;
CUDA_DEVICE_MAX_CONNECTIONS=16;
VLLM_ENFORCE_EAGER=1;
VLLM_GPU_MEMORY_UTILIZATION=0.7;
VLLM_TENSOR_PARALLEL_SIZE=8;
VLLM_DISTRIBUTED_BACKEND=mp;
VLLM_ATTENTION_BACKEND=TRITON_ATTN;
VLLM_LOAD_FORMAT=dummy;
MODEL_NAME=${{ matrix.model }};
NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so;
NCCL_TUNER_PLUGIN=none;
NCCL_CUMEM_ENABLE=0;
NCCL_BUFFSIZE=16777216;
XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592;
TRANSFER_MODE=grouped;
USE_POLYMORPHIC_MESH=0;
JAX_COORDINATOR_PORT=3389;
JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):\$(JAX_COORDINATOR_PORT);
GATEWAY_PORT=50051;
GATEWAY_URL=\$(JOBSET_NAME):\$(GATEWAY_PORT);
OUTPUT_DIR=/opt/output;
LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64;
COMMAND: |
set -x;

pip install jax[k8s];
python -c 'import jax; jax.distributed.initialize(); print(jax.devices()); print(jax.local_devices()); assert jax.process_count() > 1; assert len(jax.devices()) > len(jax.local_devices());';

PIDS=();
if [ \${NODE_RANK} = 0 ]; then
echo Starting gateway;
cd /opt/jtbx/jax-inference-offloading;
python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
PIDS+=(\$!);

echo Starting rollout;
cd /opt/jtbx/jax-inference-offloading/examples;
python rollout.py 2>&1 | tee -a rollout.log &
PIDS+=(\$!);
else
echo Starting trainer;
python trainer.py 2>&1 | tee -a trainer.log &
PIDS+=(\$!);
fi;

wait \${PIDS[@]};
EXIT_CODE=\$PIPESTATUS;
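The rank-split launch pattern shared by both workflows reduces to a small supervision loop. This is a sketch with stand-in `sleep` commands in place of the gateway/rollout/trainer processes; it waits on each PID individually and keeps the last failing status, a variant of the `wait`/`EXIT_CODE` ending above.

```shell
#!/usr/bin/env bash
# Sketch of the rank-0 / rank-N process split: each rank backgrounds its
# role's processes, then waits on all of them and records any failure.
set -u

NODE_RANK=${NODE_RANK:-0}

PIDS=()
if [ "${NODE_RANK}" = 0 ]; then
  sleep 0.1 &        # stand-in for gateway.py
  PIDS+=($!)
  sleep 0.1 &        # stand-in for rollout.py
  PIDS+=($!)
else
  sleep 0.1 &        # stand-in for trainer.py
  PIDS+=($!)
fi

EXIT_CODE=0
for pid in "${PIDS[@]}"; do
  wait "${pid}" || EXIT_CODE=$?   # keep the last nonzero status
done
echo "EXIT_CODE=${EXIT_CODE}"
```

Waiting per PID makes the final status reflect the worker processes themselves rather than the `wait` builtin.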