-
Notifications
You must be signed in to change notification settings - Fork 902
feat(runtimes): KEP-2442-jax-runtime #2878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 32 commits
32cb4cd
30411f7
cc896d2
afe56ca
cd51833
feccb57
828e55f
65d84cc
e7beb80
a249d1d
cb51c95
1bf5e03
8c2e746
51e709b
9836a7d
f8bfb05
3d695cc
b98e7b9
86f3065
1d90a48
f31f57b
0a5b2f9
1e23994
c0be33a
b161972
fc0c5e7
14ab32f
660ff1b
76842d3
a761a28
a255feb
1f3bbb5
b17a223
9839e3e
c864c14
9fd682e
5c19e97
fc78b9c
25fd59e
ad957d2
9b3f6c6
c7f23cc
15cdd84
60d1efb
1262f40
d9f6359
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,54 @@ | ||||||||||||||||||||
| FROM ghcr.io/nvidia/jax:jax as gpu-base | ||||||||||||||||||||
|
||||||||||||||||||||
| ENV DEBIAN_FRONTEND=noninteractive | ||||||||||||||||||||
|
|
||||||||||||||||||||
| RUN apt-get update && \ | ||||||||||||||||||||
|
||||||||||||||||||||
| apt-get install -y --no-install-recommends \ | ||||||||||||||||||||
| build-essential \ | ||||||||||||||||||||
| cmake \ | ||||||||||||||||||||
| git \ | ||||||||||||||||||||
| curl \ | ||||||||||||||||||||
|
||||||||||||||||||||
| libgoogle-glog-dev \ | ||||||||||||||||||||
| libgflags-dev \ | ||||||||||||||||||||
| libprotobuf-dev \ | ||||||||||||||||||||
| protobuf-compiler \ | ||||||||||||||||||||
| python3-dev \ | ||||||||||||||||||||
| python3-pip \ | ||||||||||||||||||||
| python3-setuptools && \ | ||||||||||||||||||||
| rm -rf /var/lib/apt/lists/* | ||||||||||||||||||||
|
||||||||||||||||||||
| python3-dev pip && rm -f /usr/bin/python && ln -s /usr/bin/python3 /usr/bin/python && rm -rf /var/lib/apt/lists/* |
That should allow you to use python instead of python3 as an entrypoint.
Outdated
Copilot
AI
Dec 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Dockerfile uses a base image from ghcr.io/nvidia/jax:jax, but then reinstalls JAX packages (lines 19-24). This could lead to version conflicts or redundant installations. Consider either using a minimal base image (like python:3.x or nvidia/cuda:x.x-base) and installing JAX explicitly, or use the NVIDIA JAX image without reinstalling packages. Additionally, the base image tag :jax is ambiguous and could change unexpectedly - consider using a specific version tag.
Outdated
Copilot
AI
Dec 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These pip install commands pull third-party Python packages (numpy, jax, jaxlib, jax[cuda12_pip]) from external indexes without pinning versions or verifying hashes. If any of these packages or their transitive dependencies are compromised (e.g., maintainer account takeover or registry poisoning), a malicious update would be automatically trusted at build time and executed with full privileges in this runtime image. To reduce supply-chain risk, pin each dependency to immutable versions or hashes (e.g., via a locked requirements file or hash-checking) and use only trusted package indexes.
Copilot
AI
Dec 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TPU installation logic (lines 28-31) uses || with echo statements, which means if the first pip install fails, it will echo a message and then try the second install. However, if the second install also fails, the build will continue without error. This masks installation failures. Consider using explicit error handling or removing the || echo patterns to ensure build failures are visible.
| RUN pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html || \ | |
| echo "TPU support not available" && \ | |
| pip install --no-cache-dir libtpu-nightly || \ | |
| echo "libtpu-nightly not available" | |
| RUN set -e; \ | |
| if ! pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; then \ | |
| echo "TPU support not available, attempting to install libtpu-nightly"; \ | |
| pip install --no-cache-dir libtpu-nightly; \ | |
| fi |
Copilot
AI
Dec 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pip install pipeline installs jax[tpu] and then falls back to libtpu-nightly from external indexes, again without pinning versions or verifying hashes. A compromised release of these TPU-related packages would be automatically pulled during image builds and executed with high privileges, enabling a supply-chain compromise of workloads using this runtime. Mitigate this by pinning to specific, vetted versions (or hashes) and using a controlled package index or mirror for these artifacts.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need gloo ?
Do you know if this image can support GPU, CPU, and TPU workloads at the same time?
ghcr.io/nvidia/jax:jax
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't know what you mean, shall i simplify image to only GPU?
Outdated
Copilot
AI
Dec 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pip install of absl-py and kubernetes uses unpinned versions from external package indexes, which introduces supply-chain risk for this production runtime image. If an attacker publishes a malicious version of one of these packages (or a dependency), future image builds will automatically incorporate and execute the malicious code with the container's privileges. Pin these dependencies to immutable versions or hashes and fetch them from a trusted, controlled index to limit this attack surface.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need kubernetes and absl-py?
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,51 @@ | ||||||||||
| #!/bin/bash | ||||||||||
|
|
||||||||||
| echo "=== VERSION INFORMATION ===" | ||||||||||
| echo "Python version: $(python3 --version)" | ||||||||||
| echo "Pip version: $(pip --version)" | ||||||||||
| echo "" | ||||||||||
|
|
||||||||||
| echo "=== JAX Information ===" | ||||||||||
| python3 -c "import jax; print(f'JAX version: {jax.__version__}')" | ||||||||||
| python3 -c "import jaxlib; print(f'JAXLib version: {jaxlib.__version__}')" | ||||||||||
| echo "" | ||||||||||
|
|
||||||||||
| echo "=== NCCL Information ===" | ||||||||||
| if dpkg -l | grep -q libnccl2; then | ||||||||||
| echo "NCCL version: $(dpkg -l | grep libnccl2 | awk '{print $3}')" | ||||||||||
| else | ||||||||||
| echo "NCCL: Not available via package manager" | ||||||||||
| fi | ||||||||||
| if command -v nvcc &> /dev/null; then | ||||||||||
| echo "CUDA version: $(nvcc --version | grep 'release' | awk '{print $5}')" | ||||||||||
| fi | ||||||||||
| echo "" | ||||||||||
|
|
||||||||||
| echo "=== TPU Information ===" | ||||||||||
| python3 -c " | ||||||||||
| try: | ||||||||||
| import jax | ||||||||||
| devices = jax.devices() | ||||||||||
| tpu_devices = [d for d in devices if d.platform == 'tpu'] | ||||||||||
| if tpu_devices: | ||||||||||
| print(f'TPU devices found: {len(tpu_devices)}') | ||||||||||
| for d in tpu_devices: | ||||||||||
| print(f' - {d}') | ||||||||||
| else: | ||||||||||
| print('No TPU devices found') | ||||||||||
| except Exception as e: | ||||||||||
| print(f'Error checking TPU devices: {e}') | ||||||||||
| " | ||||||||||
| echo "" | ||||||||||
|
|
||||||||||
| echo "=== Gloo Information ===" | ||||||||||
| echo "Gloo: Built from source (commit 43b7acbf372cdce14075f3526e39153b7e433b53)" | ||||||||||
| if [ -f /usr/local/lib/libgloo.a ]; then | ||||||||||
| echo "Gloo library: Found at /usr/local/lib/libgloo.so" | ||||||||||
|
||||||||||
| echo "Gloo library: Found at /usr/local/lib/libgloo.so" | |
| echo "Gloo library: Found static library at /usr/local/lib/libgloo.a" | |
| elif [ -f /usr/local/lib/libgloo.so ]; then | |
| echo "Gloo library: Found shared library at /usr/local/lib/libgloo.so" |
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| apiVersion: trainer.kubeflow.org/v1alpha1 | ||
| kind: ClusterTrainingRuntime | ||
| metadata: | ||
| name: jax-distributed | ||
| labels: | ||
| trainer.kubeflow.org/framework: jax | ||
| spec: | ||
| mlPolicy: | ||
| numNodes: 1 | ||
| jax: {} | ||
| template: | ||
| spec: | ||
| replicatedJobs: | ||
| - name: node | ||
| template: | ||
| metadata: | ||
| labels: | ||
| trainer.kubeflow.org/trainjob-ancestor-step: trainer | ||
| spec: | ||
| template: | ||
| spec: | ||
| containers: | ||
| - name: node | ||
| image: ghcr.io/kubeflow/trainer/jax-runtime |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please can you keep the order of image types, and move it after
mlx-runtime?