trainer: add JAX trainer guide for TPU#4343
trainer: add JAX trainer guide for TPU#4343google-oss-prow[bot] merged 2 commits intokubeflow:masterfrom
Conversation
|
Hi @siyuanfoundation. Thanks for your PR. I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with Once the patch is verified, the new status will be reflected by the I understand the commands that are listed here. DetailsInstructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
|
🚫 This command cannot be processed. Only organization members or owners can use the commands. |
Signed-off-by: siyuanfoundation <sizhang@google.com>
|
/cc @andreyvelich |
andreyvelich
left a comment
There was a problem hiding this comment.
Looks great, overall lgtm, left a few comments.
Thank you for this @siyuanfoundation!
/assign @akshaychitneni @kubeflow/kubeflow-trainer-team
| ## JAX on TPU Overview | ||
|
|
||
| JAX on TPU requires a different runtime environment than GPU. Specifically: | ||
| - **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`). |
There was a problem hiding this comment.
Maybe you can add example how ClusterTrainingRuntime might look like?
Do you know if users want to set node selectors per job, or this is something that cluster admins can configure when they create reusable ClusterTrainingRuntime?
As @kaisoz mentioned in this PR, our default ClusterTrainingRuntime's image doesn't support TPUs: kubeflow/trainer#3151 (comment)
cc @kubeflow/kubeflow-trainer-team
| JAX on TPU requires a different runtime environment than GPU. Specifically: | ||
| - **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`). | ||
| - **Resources**: You must request `google.com/tpu` resources. | ||
| - **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes. |
There was a problem hiding this comment.
I know that JobSet also supports Exclusive Topology for TPU workload placement:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
Are there any interest from GKE team to showcase how this can be used with the TrainJob too?
Additionally, TPU multi-slice examples: kubernetes-sigs/jobset#1168
There was a problem hiding this comment.
the multi-slice support will depend on kubeflow/trainer#2318
|
|
||
| ### Node Selectors and Topology | ||
|
|
||
| When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels: |
| apiVersion: cloud.google.com/v1 | ||
| kind: ComputeClass |
There was a problem hiding this comment.
Does it require DRA driver to be installed? Shall we mention this?
There was a problem hiding this comment.
No, it does not.
|
/ok-to-test |
|
Approvals successfully granted for pending runs. |
Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com> Signed-off-by: Siyuan Zhang <10984162+siyuanfoundation@users.noreply.github.com>
There was a problem hiding this comment.
I think, we should be in good shape to merge this.
We can address this in the followup if needed: #4343 (comment)
Thanks for this work @siyuanfoundation!
/lgtm
/approve
|
[APPROVALNOTIFIER] This PR is APPROVED This pull-request has been approved by: andreyvelich The full list of commands accepted by this bot can be found here. The pull request process is described here DetailsNeeds approval from an approver in each of these files:
Approvers can indicate their approval by writing |
Description of Changes
This PR adds a JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer on TPUs.
Related Issues
Related: kubeflow/trainer#3183
Checklist