-
Notifications
You must be signed in to change notification settings - Fork 69
JAX-vLLM Offloading k8s (GKE) #1797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 63 commits
Commits
Show all changes
65 commits
Select commit
Hold shift + click to select a range
6db8716
Add k8s JAX-vLLM offloading example
aybchan b398510
Update gateway URL
aybchan e0a1b67
Add two-node manifest
aybchan 485e2a3
Add 8:8 logs
aybchan fd9c38f
Add hard-coded 2x node jobset example
aybchan 0000997
patch vLLM weight loader
yhtang 82fe0ce
bump tunix version
yhtang a3f36ff
Add jax[k8s] extras to install
aybchan e069f1b
Organize deployment manifests
aybchan aa926cb
Set missing env. vars
aybchan 5591ce4
address PR comments
yhtang d45fa3a
address PR comments
yhtang 771f97d
Remove debug trace
aybchan db4861b
Add JAX-vLLM workflow
aybchan 20802b8
Fix JobSet command
aybchan 794ff86
Add xpk patch, update env file, patch composite action
aybchan bc9d877
Enable image pull secret set
aybchan b64763f
Set jobset dot env path
aybchan f8cd259
Refactor CI workflows
aybchan b75f09f
Fix workflow
aybchan ed9d8b0
Fix workflow
aybchan 8c85fe7
Fix workflow
aybchan 19391d6
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan 7e5e1ce
Add build to pipeline
aybchan 2ab8e10
Set test build
aybchan 77d6c82
Disable arm64 build for now
aybchan 78800df
Update deprecataed variable
aybchan a081257
revert change related to a vllm tokenizer load failure that is no lon…
yhtang 8b3440f
remove tunix version pin
yhtang 4372984
revert JAX version bump
yhtang b8c7b9c
Merge branch 'yhtang/vllm-0.11-bump' into aybchan/jax-vllm-offloading…
aybchan 989cb24
Merge workflow definitions
aybchan e07922c
Set run 70B transfer
aybchan 0cc6920
Set arm64 and amd64 build separately
aybchan bd49355
Fix artifact name collision
aybchan ffa6706
Update keyword name due to tunix 974da5
aybchan f111e2f
Make xpk composite action changes backwards compatible
aybchan 2bd3206
Add working k8s GRPO recipe
aybchan 18950d7
Update GRPO workflow
aybchan df32d3b
Fix workflow
aybchan bdaafd0
Fix inline command
aybchan aba0842
Remove debug logs
aybchan 4f7c0ec
Set consume step output
aybchan a1f5a4b
Set workload name strictly lower case
aybchan 9c4b2bd
Handle invalid jobset name
aybchan db46141
Remove unnecessary unbound variable
aybchan 9959cde
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan 8d211d7
Update trigger
aybchan 0c73a51
Remove build-arm64 job
aybchan 45e7c53
Update job dependencies
aybchan 869d626
Fix unbound variable
aybchan e9cfd24
Refactor gke-xpk action to set environment variables from string
aybchan 288d621
Remove environment variable setting in start up script
aybchan 83511ee
Set multi-line env var delimiter
aybchan 1dee288
Update image var reference
aybchan 13ed7b5
Fix script
aybchan edc596b
Fix unbound variable reference
aybchan c35b73e
Handle nested env arg
aybchan b9b428f
Update variable value
aybchan eb917e2
Cleanup workflow definition, adapt nccl, maxtext envs
aybchan 9f6a450
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan 031882f
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan 06e1111
Merge branch 'main' into aybchan/jax-vllm-offloading-k8s
aybchan 4a41013
Remove test files
aybchan f919b29
Refactor env exporting due to xpk jobset duplication
aybchan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| kubectl apply -f transfer/deployment/gateway-pod.yml | ||
| kubectl apply -f transfer/deployment/gateway-svc.yml | ||
|
|
||
| kubectl apply -f huggingface-secret.yml | ||
|
|
||
| kubectl apply -f transfer/deployment/rollout.yml | ||
| kubectl apply -f transfer/deployment/trainer.yml |
280 changes: 280 additions & 0 deletions
280
.github/gke-workflow/jax-vllm-offloading/grpo/jobset.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,280 @@ | ||
| apiVersion: jobset.x-k8s.io/v1alpha2 | ||
| kind: JobSet | ||
| metadata: | ||
| annotations: | ||
| name: jax-vllm-grpo | ||
| namespace: default | ||
| spec: | ||
| network: | ||
| enableDNSHostnames: true | ||
| publishNotReadyAddresses: true | ||
| replicatedJobs: | ||
| - name: slice-job | ||
| replicas: 1 | ||
| template: | ||
| metadata: {} | ||
| spec: | ||
| backoffLimit: 0 | ||
| completionMode: Indexed | ||
| completions: 2 | ||
| parallelism: 2 | ||
| template: | ||
| metadata: | ||
| annotations: | ||
| devices.gke.io/container.tcpxo-daemon: | | ||
| - path: /dev/nvidia0 | ||
| - path: /dev/nvidia1 | ||
| - path: /dev/nvidia2 | ||
| - path: /dev/nvidia3 | ||
| - path: /dev/nvidia4 | ||
| - path: /dev/nvidia5 | ||
| - path: /dev/nvidia6 | ||
| - path: /dev/nvidia7 | ||
| - path: /dev/nvidiactl | ||
| - path: /dev/nvidia-uvm | ||
| - path: /dev/dmabuf_import_helper | ||
| networking.gke.io/default-interface: eth0 | ||
| networking.gke.io/interfaces: |- | ||
| [ | ||
| {"interfaceName":"eth0","network":"default"}, | ||
| {"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"}, | ||
| {"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"}, | ||
| {"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"}, | ||
| {"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"}, | ||
| {"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"}, | ||
| {"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"}, | ||
| {"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"}, | ||
| {"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"} | ||
| ] | ||
| spec: | ||
| imagePullSecrets: | ||
| - name: jax-toolbox-ghcr | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| containers: | ||
| - name: gpu-image | ||
| image: ghcr.io/nvidia/jax-toolbox-internal:19751502075-jio-amd64 | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| imagePullPolicy: Always | ||
| command: | ||
| - bash | ||
| - -c | ||
| - | | ||
| pip install jax[k8s] | ||
| python -c " | ||
| import jax | ||
| jax.distributed.initialize() | ||
| print(jax.devices()) | ||
| print(jax.local_devices()) | ||
| assert jax.process_count() > 1 | ||
| assert len(jax.devices()) > len(jax.local_devices())" | ||
|
|
||
| PIDS=() | ||
| # hard-code split of vLLM-JAX on 1x node each on 2x slice jobset | ||
| if [ ${NODE_RANK} = "0" ]; then | ||
| echo "Starting gateway" | ||
| cd /opt/jtbx/jax-inference-offloading | ||
| python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log & | ||
| PIDS+=($!) | ||
|
|
||
| echo "Starting rollout" | ||
| cd /opt/jtbx/jax-inference-offloading/examples | ||
| python rollout.py 2>&1 | tee -a rollout.log & | ||
| PIDS+=($!) | ||
| else | ||
| echo "Starting trainer" | ||
| export MODEL_PATH=$(python "download_model.py" --hub=hf --model=${MODEL_NAME} --ignore="*.pth") | ||
| python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log & | ||
| PIDS+=($!) | ||
| fi | ||
|
|
||
| wait "${PIDS[@]}" | ||
| echo "All done" | ||
| env: | ||
| # jobset | ||
| - name: REPLICATED_JOB_NAME | ||
| valueFrom: | ||
| fieldRef: | ||
| fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] | ||
| - name: JOBSET_NAME | ||
| valueFrom: | ||
| fieldRef: | ||
| fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] | ||
| - name: NODE_RANK | ||
| valueFrom: | ||
| fieldRef: | ||
| fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index'] | ||
| - name: USE_GPUDIRECT | ||
| value: tcpxo | ||
| - name: GPUS_PER_NODE | ||
| value: "8" | ||
|
|
||
| - name: LD_LIBRARY_PATH | ||
| value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64" | ||
Steboss marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # huggingface | ||
| - name: HF_TOKEN | ||
| valueFrom: | ||
| secretKeyRef: | ||
| name: hf-token-secret | ||
| key: token | ||
| - name: MODEL_NAME | ||
| value: "meta-llama/Llama-3.1-8B-Instruct" | ||
| - name: SCRATCHDIR | ||
| value: "/opt/scratch" | ||
|
|
||
| # gateway | ||
| - name: GATEWAY_PORT | ||
| value: "50051" | ||
| - name: GATEWAY_URL | ||
| value: "$(JOBSET_NAME):$(GATEWAY_PORT)" | ||
|
|
||
| # JAX | ||
| - name: JAX_COORDINATOR_PORT | ||
| value: "3389" | ||
| - name: JAX_COORDINATOR_ADDRESS | ||
| value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389 | ||
|
|
||
| # CUDA | ||
| - name: CUDA_VISIBLE_DEVICES | ||
| value: "0,1,2,3,4,5,6,7" | ||
| - name: CUDA_DEVICE_ORDER | ||
| value: "PCI_BUS_ID" | ||
| - name: CUDA_DEVICE_MAX_CONNECTIONS | ||
| value: "16" | ||
|
|
||
| # vLLM | ||
| - name: VLLM_ENFORCE_EAGER | ||
| value: "1" | ||
| - name: VLLM_GPU_MEMORY_UTILIZATION | ||
| value: "0.7" | ||
| - name: VLLM_TENSOR_PARALLEL_SIZE | ||
| value: "8" | ||
| - name: VLLM_DISTRIBUTED_BACKEND | ||
| value: "mp" | ||
Steboss marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| - name: VLLM_ATTENTION_BACKEND | ||
| value: "TRITON_ATTN" | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| - name: VLLM_LOAD_FORMAT | ||
| value: "dummy" | ||
|
|
||
| # NCCL | ||
| - name: NCCL_NET_PLUGIN | ||
| value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so" | ||
| - name: NCCL_TUNER_PLUGIN | ||
| value: "none" | ||
| - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| value: /dev/aperture_devices | ||
| - name: NCCL_CUMEM_ENABLE | ||
| value: "0" # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| - name: NCCL_BUFFSIZE | ||
| value: "16777216" | ||
|
|
||
| # XLA | ||
| - name: XLA_PYTHON_CLIENT_MEM_FRACTION | ||
| value: "0.95" | ||
| - name: XLA_FLAGS | ||
| value: "--xla_gpu_enable_latency_hiding_scheduler=true | ||
| --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL | ||
| --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 | ||
| --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 | ||
| --xla_gpu_all_gather_combine_threshold_bytes=8589934592 | ||
| --xla_gpu_all_reduce_combine_threshold_bytes=8589934592" | ||
|
|
||
| # trainer | ||
| - name: TRANSFER_MODE | ||
| value: "grouped" | ||
| - name: USE_POLYMORPHIC_MESH | ||
| value: "0" | ||
| - name: JAX_COMPILATION_CACHE_DIR | ||
| value: /opt/jax-compilation | ||
| - name: JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS | ||
| value: "0.1" | ||
| - name: RUN_MODE | ||
| value: "timing" | ||
| - name: ROLLOUT_ENGINE | ||
| value: "vllm_gpu" | ||
| - name: GRPO_TRAIN_MICRO_BATCH_SIZE | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| value: "2" | ||
|
|
||
|
|
||
| ports: | ||
| - containerPort: 50051 | ||
| protocol: TCP | ||
| - containerPort: 3389 | ||
| protocol: TCP | ||
| resources: | ||
| limits: | ||
| nvidia.com/gpu: "8" | ||
| securityContext: | ||
aybchan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| privileged: true | ||
| volumeMounts: | ||
| - mountPath: /dev/aperture_devices | ||
| name: aperture-devices | ||
| - mountPath: /usr/local/nvidia | ||
| name: libraries | ||
| - mountPath: /dev/shm | ||
| name: dshm | ||
| - mountPath: /opt/scratch | ||
| name: scratch | ||
| dnsPolicy: ClusterFirstWithHostNet | ||
| initContainers: | ||
| - args: | ||
| - |- | ||
| set -ex | ||
| chmod 755 /fts/entrypoint_rxdm_container.sh | ||
| /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr | ||
| command: | ||
| - /bin/sh | ||
| - -c | ||
| env: | ||
| - name: LD_LIBRARY_PATH | ||
| value: /usr/local/nvidia/lib64 | ||
| image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12 | ||
| imagePullPolicy: Always | ||
| name: tcpxo-daemon | ||
| resources: {} | ||
| restartPolicy: Always | ||
| securityContext: | ||
| capabilities: | ||
| add: | ||
| - NET_ADMIN | ||
| - NET_BIND_SERVICE | ||
| volumeMounts: | ||
| - mountPath: /usr/local/nvidia | ||
| name: libraries | ||
| - mountPath: /hostsysfs | ||
| name: sys | ||
| - mountPath: /hostprocsysfs | ||
| name: proc-sys | ||
| nodeSelector: | ||
| cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb | ||
| priorityClassName: high | ||
| terminationGracePeriodSeconds: 30 | ||
| tolerations: | ||
| - key: nvidia.com/gpu | ||
| operator: Exists | ||
| - effect: NoSchedule | ||
| key: user-workload | ||
| operator: Equal | ||
| value: "true" | ||
| volumes: | ||
| - hostPath: | ||
| path: /home/kubernetes/bin/nvidia | ||
| name: libraries | ||
| - hostPath: | ||
| path: /sys | ||
| name: sys | ||
| - hostPath: | ||
| path: /proc/sys | ||
| name: proc-sys | ||
| - hostPath: | ||
| path: /dev/aperture_devices | ||
| name: aperture-devices | ||
| - emptyDir: | ||
| medium: Memory | ||
| name: dshm | ||
| - emptyDir: | ||
| sizeLimit: 2Gi | ||
| name: scratch | ||
| startupPolicy: | ||
| startupPolicyOrder: AnyOrder | ||
| successPolicy: | ||
| operator: All | ||
| ttlSecondsAfterFinished: 100000 | ||
8 changes: 8 additions & 0 deletions
8
.github/gke-workflow/jax-vllm-offloading/huggingface-secret.yml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| apiVersion: v1 | ||
| kind: Secret | ||
| metadata: | ||
| name: hf-token-secret | ||
| namespace: default | ||
| type: Opaque | ||
| stringData: | ||
| token: {{ HF_TOKEN}} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.