Skip to content

trainer: add JAX trainer guide for TPU#4343

Merged
google-oss-prow[bot] merged 2 commits intokubeflow:masterfrom
siyuanfoundation:tpu
Mar 27, 2026
Merged

trainer: add JAX trainer guide for TPU#4343
google-oss-prow[bot] merged 2 commits intokubeflow:masterfrom
siyuanfoundation:tpu

Conversation

@siyuanfoundation
Copy link
Copy Markdown
Contributor

@siyuanfoundation siyuanfoundation commented Mar 16, 2026

Description of Changes

This PR adds a JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer on TPUs.

Related Issues

Related: kubeflow/trainer#3183

Checklist

image

@google-oss-prow
Copy link
Copy Markdown

Hi @siyuanfoundation. Thanks for your PR.

I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with /ok-to-test on its own line. Until that is done, I will not automatically test new commits in this PR, but the usual testing commands by org members will still work. Regular contributors should join the org to skip this step.

Once the patch is verified, the new status will be reflected by the ok-to-test label.

I understand the commands that are listed here.

Details

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@google-oss-prow google-oss-prow bot added size/L area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator labels Mar 16, 2026
@github-actions
Copy link
Copy Markdown

🚫 This command cannot be processed. Only organization members or owners can use the commands.

Signed-off-by: siyuanfoundation <sizhang@google.com>
@siyuanfoundation
Copy link
Copy Markdown
Contributor Author

/cc @andreyvelich

@google-oss-prow google-oss-prow bot requested a review from andreyvelich March 16, 2026 20:59
Copy link
Copy Markdown
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, overall lgtm, left a few comments.
Thank you for this @siyuanfoundation!
/assign @akshaychitneni @kubeflow/kubeflow-trainer-team

## JAX on TPU Overview

JAX on TPU requires a different runtime environment than GPU. Specifically:
- **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
Copy link
Copy Markdown
Member

@andreyvelich andreyvelich Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can add example how ClusterTrainingRuntime might look like?
Do you know if users want to set node selectors per job, or this is something that cluster admins can configure when they create reusable ClusterTrainingRuntime?

As @kaisoz mentioned in this PR, our default ClusterTrainingRuntime's image doesn't support TPUs: kubeflow/trainer#3151 (comment)
cc @kubeflow/kubeflow-trainer-team

JAX on TPU requires a different runtime environment than GPU. Specifically:
- **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
- **Resources**: You must request `google.com/tpu` resources.
- **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that JobSet also supports Exclusive Topology for TPU workload placement:

alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool

Are there any interest from GKE team to showcase how this can be used with the TrainJob too?

Additionally, TPU multi-slice examples: kubernetes-sigs/jobset#1168

cc @GiuseppeTT @imreddy13

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the multi-slice support will depend on kubeflow/trainer#2318


### Node Selectors and Topology

When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This URL doesn't work.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

Comment on lines +18 to +19
apiVersion: cloud.google.com/v1
kind: ComputeClass
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it require DRA driver to be installed? Shall we mention this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it does not.

@andreyvelich
Copy link
Copy Markdown
Member

/ok-to-test

@github-actions
Copy link
Copy Markdown

Approvals successfully granted for pending runs.

@siyuanfoundation siyuanfoundation changed the title [trainer] add Jax trainer guide for TPU trainer : add Jax trainer guide for TPU Mar 27, 2026
Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
Signed-off-by: Siyuan Zhang <10984162+siyuanfoundation@users.noreply.github.com>
@andreyvelich andreyvelich changed the title trainer : add Jax trainer guide for TPU trainer: add JAX trainer guide for TPU Mar 27, 2026
Copy link
Copy Markdown
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, we should be in good shape to merge this.
We can address this in the followup if needed: #4343 (comment)
Thanks for this work @siyuanfoundation!
/lgtm
/approve

@google-oss-prow
Copy link
Copy Markdown

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: andreyvelich

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Details Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@google-oss-prow google-oss-prow bot merged commit 682bb9c into kubeflow:master Mar 27, 2026
7 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

approved area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator lgtm ok-to-test size/L

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants