46 commits
32cb4cd
setup jax plugin
mahdikhashan Oct 5, 2025
30411f7
autogenerated files
mahdikhashan Oct 5, 2025
cc896d2
update jax dockerfile base image version
mahdikhashan Oct 5, 2025
afe56ca
fix dockerfile
mahdikhashan Oct 5, 2025
cd51833
fix ci
mahdikhashan Oct 5, 2025
feccb57
fix
mahdikhashan Oct 5, 2025
828e55f
add jax_distributed.yaml to runtimes
mahdikhashan Oct 5, 2025
65d84cc
wip: jax plugin
mahdikhashan Oct 5, 2025
e7beb80
downgrade jax image version
mahdikhashan Oct 6, 2025
a249d1d
wip: jax plugin
mahdikhashan Oct 6, 2025
cb51c95
fix tests
mahdikhashan Oct 6, 2025
1bf5e03
wip: fix ci
mahdikhashan Oct 6, 2025
8c2e746
wip: fix go tests
mahdikhashan Oct 6, 2025
51e709b
pass tests
mahdikhashan Oct 18, 2025
9836a7d
chore(ci): Enable Kubernetes API Linter (#2858)
astefanutti Oct 17, 2025
f8bfb05
feat(cache): KEP-2655 - Add build pipeline and address vulnerabilitie…
akshaychitneni Oct 18, 2025
3d695cc
setup jax plugin
mahdikhashan Oct 5, 2025
b98e7b9
wip: jax plugin
mahdikhashan Oct 5, 2025
86f3065
remove duplicate component
mahdikhashan Dec 29, 2025
1d90a48
remove pointers
mahdikhashan Dec 29, 2025
f31f57b
add jax to registry
mahdikhashan Dec 29, 2025
0a5b2f9
build jax image
mahdikhashan Dec 29, 2025
1e23994
add jax plugin
mahdikhashan Dec 29, 2025
c0be33a
fix version.sh filepath
mahdikhashan Dec 29, 2025
b161972
update api
mahdikhashan Dec 29, 2025
fc0c5e7
fix fmt
mahdikhashan Dec 29, 2025
14ab32f
fix tests
mahdikhashan Dec 29, 2025
660ff1b
fix lint error
mahdikhashan Dec 29, 2025
76842d3
fix lint error
mahdikhashan Dec 29, 2025
a761a28
fix lint error
mahdikhashan Dec 29, 2025
a255feb
fix lint errors
mahdikhashan Dec 29, 2025
1f3bbb5
update api
mahdikhashan Dec 29, 2025
b17a223
Merge branch 'kubeflow:master' into jax-runtime-impl
mahdikhashan Dec 29, 2025
9839e3e
add jax e2e-tests
mahdikhashan Dec 29, 2025
c864c14
update dockerfile
mahdikhashan Dec 29, 2025
9fd682e
fix image order
mahdikhashan Jan 6, 2026
5c19e97
remove gitkeep
mahdikhashan Jan 6, 2026
fc78b9c
update dockerfile
mahdikhashan Jan 6, 2026
25fd59e
use latest nvidia-jax image version
mahdikhashan Jan 12, 2026
ad957d2
remove validation block
mahdikhashan Jan 12, 2026
9b3f6c6
remove extra check
mahdikhashan Jan 12, 2026
c7f23cc
remove code
mahdikhashan Jan 12, 2026
15cdd84
update jax plugin
mahdikhashan Jan 12, 2026
60d1efb
update tests
mahdikhashan Jan 12, 2026
1262f40
remove tests
mahdikhashan Jan 12, 2026
d9f6359
remove jax from core tests for validation
mahdikhashan Jan 12, 2026
3 changes: 3 additions & 0 deletions .github/workflows/build-and-push-images.yaml
@@ -38,6 +38,9 @@ jobs:
- component-name: mlx-runtime
dockerfile: cmd/runtimes/mlx/Dockerfile
platforms: linux/amd64
- component-name: jax-runtime
dockerfile: cmd/runtimes/jax/Dockerfile
platforms: linux/amd64,linux/arm64
- component-name: torchtune-trainer
dockerfile: cmd/trainers/torchtune/Dockerfile
platforms: linux/amd64,linux/arm64
20 changes: 20 additions & 0 deletions api/openapi-spec/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

@@ -47,6 +47,9 @@ spec:
description: mlPolicy provides the ML-specific parameters for the
model training.
properties:
jax:
description: jax defines the configuration for the JAX Runtime
type: object
mpi:
description: mpi defines the configuration for the MPI Runtime.
properties:
@@ -47,6 +47,9 @@ spec:
description: mlPolicy provides the ML-specific parameters for the
model training.
properties:
jax:
description: jax defines the configuration for the JAX Runtime
type: object
mpi:
description: mpi defines the configuration for the MPI Runtime.
properties:
36 changes: 36 additions & 0 deletions cmd/runtimes/jax/Dockerfile
@@ -0,0 +1,36 @@
FROM ghcr.io/nvidia/jax:jax-2026-01-04 as gpu-base
ENV DEBIAN_FRONTEND=noninteractive

RUN apt update && apt install -y --no-install-recommends \
build-essential \
cmake \
git \
libgoogle-glog-dev \
libgflags-dev \
libprotobuf-dev \
protobuf-compiler \
python3-dev pip && rm -f /usr/bin/python && \
ln -s /usr/bin/python3 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir --upgrade pip

FROM gpu-base as tpu-base

RUN pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html || \
echo "TPU support not available" && \
pip install --no-cache-dir libtpu-nightly || \
echo "libtpu-nightly not available"
Comment on lines +20 to +23
Copilot AI Dec 29, 2025

The TPU installation logic (lines 20-23) chains || and && with echo statements. If the first pip install fails, it echoes a message and then tries the second install; because the shell evaluates || and && left to right with equal precedence, the second install is also attempted when the first one succeeds. If the second install fails as well, the trailing || echo makes the step exit successfully, so the build continues without error. This masks installation failures. Consider using explicit error handling or removing the || echo patterns so that build failures are visible.

Suggested change
- RUN pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html || \
- echo "TPU support not available" && \
- pip install --no-cache-dir libtpu-nightly || \
- echo "libtpu-nightly not available"
+ RUN set -e; \
+ if ! pip install --no-cache-dir "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; then \
+ echo "TPU support not available, attempting to install libtpu-nightly"; \
+ pip install --no-cache-dir libtpu-nightly; \
+ fi
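
A minimal shell sketch of the behavior this comment describes, assuming a POSIX shell, where && and || have equal precedence and are evaluated left to right. The functions install_primary and install_fallback and the variables PRIMARY_STATUS and FALLBACK_STATUS are hypothetical stand-ins for the two pip commands and their exit codes; they are not part of the PR.

```shell
#!/bin/sh
# Hypothetical stand-ins for the two pip installs (not from the PR).
# Each returns the exit status stored in the matching variable.
install_primary() { return "$PRIMARY_STATUS"; }
install_fallback() { echo "fallback attempted"; return "$FALLBACK_STATUS"; }

# Same shape as the original RUN line. The shell groups it as
# ((install_primary || echo) && install_fallback) || echo,
# so the fallback also runs after a successful primary install,
# and the final echo turns any failure into exit status 0.
original_chain() {
  install_primary || echo "primary unavailable" && \
    install_fallback || echo "fallback unavailable"
}

# Same shape as the suggested change: the fallback runs only when the
# primary install fails, and a failing fallback propagates its status.
explicit_chain() {
  if ! install_primary; then
    echo "primary unavailable, trying fallback"
    install_fallback
  fi
}

# Demo: primary succeeds, fallback would fail.
PRIMARY_STATUS=0
FALLBACK_STATUS=1
original_chain
echo "original_chain exit status: $?"
explicit_chain
echo "explicit_chain exit status: $?"
```

In this sketch, original_chain reports success in every combination of statuses, while explicit_chain fails only when both installs fail, which is the case a Docker build would need to surface.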

Comment on lines +20 to +23
Copilot AI Dec 29, 2025

This pip install pipeline installs jax[tpu] and then falls back to libtpu-nightly from external indexes, again without pinning versions or verifying hashes. A compromised release of these TPU-related packages would be automatically pulled during image builds and executed with high privileges, enabling a supply-chain compromise of workloads using this runtime. Mitigate this by pinning to specific, vetted versions (or hashes) and using a controlled package index or mirror for these artifacts.


FROM tpu-base as gloo-base

RUN git clone https://github.com/facebookincubator/gloo.git \
&& cd gloo \
&& git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
&& mkdir build \
&& cd build \
&& cmake ../ \
&& make \
&& make install
Comment on lines +27 to +34
Member

Why do we need gloo?
Do you know if this image can support GPU, CPU, and TPU workloads at the same time?
ghcr.io/nvidia/jax:jax

Member Author

I don't know what you mean. Shall I simplify the image to GPU only?


FROM gloo-base as production
@@ -47,6 +47,9 @@ spec:
description: mlPolicy provides the ML-specific parameters for the
model training.
properties:
jax:
description: jax defines the configuration for the JAX Runtime
type: object
mpi:
description: mpi defines the configuration for the MPI Runtime.
properties:
@@ -47,6 +47,9 @@ spec:
description: mlPolicy provides the ML-specific parameters for the
model training.
properties:
jax:
description: jax defines the configuration for the JAX Runtime
type: object
mpi:
description: mpi defines the configuration for the MPI Runtime.
properties:
24 changes: 24 additions & 0 deletions manifests/base/runtimes/jax_distributed.yaml
@@ -0,0 +1,24 @@
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
name: jax-distributed
labels:
trainer.kubeflow.org/framework: jax
spec:
mlPolicy:
numNodes: 1
jax: {}
template:
spec:
replicatedJobs:
- name: node
template:
metadata:
labels:
trainer.kubeflow.org/trainjob-ancestor-step: trainer
spec:
template:
spec:
containers:
- name: node
image: ghcr.io/kubeflow/trainer/jax-runtime
1 change: 1 addition & 0 deletions manifests/base/runtimes/kustomization.yaml
@@ -3,5 +3,6 @@ kind: Kustomization
resources:
- deepspeed_distributed.yaml
- mlx_distributed.yaml
- jax_distributed.yaml
- torch_distributed.yaml
- torchtune
2 changes: 2 additions & 0 deletions manifests/overlays/runtimes/kustomization.yaml
@@ -15,3 +15,5 @@ images:
newTag: latest
- name: ghcr.io/kubeflow/trainer/deepspeed-runtime
newTag: latest
- name: ghcr.io/kubeflow/trainer/jax-runtime
newTag: latest
7 changes: 7 additions & 0 deletions pkg/apis/trainer/v1alpha1/trainingruntime_types.go
@@ -194,6 +194,10 @@ type MLPolicySource struct {
// mpi defines the configuration for the MPI Runtime.
// +optional
MPI *MPIMLPolicySource `json:"mpi,omitempty"`

// jax defines the configuration for the JAX Runtime
// +optional
JAX *JAXMLPolicySource `json:"jax,omitempty"`
}

// TorchMLPolicySource represents a PyTorch runtime configuration.
@@ -239,6 +243,9 @@ type TorchElasticPolicy struct {
Metrics []autoscalingv2.MetricSpec `json:"metrics,omitempty"`
}

// JAXMLPolicySource represents a jax runtime configuration.
type JAXMLPolicySource struct{}

// MPIMLPolicySource represents a MPI runtime configuration.
type MPIMLPolicySource struct {
// numProcPerNode is the number of processes per node.
21 changes: 21 additions & 0 deletions pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 26 additions & 2 deletions pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions pkg/client/applyconfiguration/trainer/v1alpha1/mlpolicy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions pkg/client/applyconfiguration/trainer/v1alpha1/mlpolicysource.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
