Commit 51732bd

Update content/en/docs/components/trainer/user-guides/jax.md
Co-authored-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
Signed-off-by: Siyuan Zhang <10984162+siyuanfoundation@users.noreply.github.com>
1 parent 7e03585 commit 51732bd

2 files changed: 10 additions and 4 deletions

content/en/docs/components/trainer/user-guides/jax-tpu.md

Lines changed: 9 additions & 3 deletions
@@ -12,8 +12,8 @@ This guide describes how to use TrainJob to train or fine-tune AI models with
 ## Prerequisites
 
 Before exploring this guide, make sure to follow:
-- [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/)
-- [GKE Cloud TPU documentation](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus) to set up a GKE cluster with TPU nodes. For example, for an autopilot GKE cluster, you can create a TPU custom ComputeClass like
+- [The Getting Started guide](/docs/components/trainer/user-guides/)
+- [GKE Cloud TPU documentation](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus) to set up a GKE cluster with TPU nodes. For example, for an autopilot GKE cluster, you can create a [TPU custom ComputeClass](https://docs.cloud.google.com/kubernetes-engine/docs/how-to/tpus#custom-compute-classes) like
 ```
 apiVersion: cloud.google.com/v1
 kind: ComputeClass
@@ -348,14 +348,20 @@ print("\n".join(client.get_job_logs(name=job_id)))
 
 ### Node Selectors and Topology
 
-When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels:
+When running on GKE, TPUs are managed via specific node pools and you must match the proper `node_selector` and `tolerations` to your TPU node pool labels.
+If you are using [custom ComputeClasses](https://docs.cloud.google.com/kubernetes-engine/docs/how-to/tpus#custom-compute-classes), add the following `node_selector` and `tolerations` to your TPU node pool labels:
 
 | Label | Example Value |
 |-------|---------------|
 | `cloud.google.com/compute-class` | `tpu-multihost-v5-8` |
 | `cloud.google.com/gke-tpu-accelerator` | `tpu-v5-lite-podslice` |
 | `cloud.google.com/gke-tpu-topology` | `2x4` |
 
+| Toleration Key | Toleration Operator | Toleration Effect |
+|-------|---------------|---------------|
+| `google.com/tpu` | `Exists` | `NoSchedule` |
+| `cloud.google.com/compute-class` | `Exists` | `NoSchedule` |
+
 ### Environment Variables
 
 | Variable | Description |
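
Reviewer note: the label and toleration tables added in this hunk map directly onto the standard Kubernetes pod-spec fields `nodeSelector` and `tolerations`. The sketch below assembles them as plain Python dicts, using the example values from the tables. The helper name `tpu_scheduling_fragment` is hypothetical and is not part of the Kubeflow Trainer SDK; how the fragment is passed to a TrainJob is not shown in this diff and is left out here.

```python
# Hypothetical helper for illustration only; not a Kubeflow Trainer SDK API.
def tpu_scheduling_fragment(compute_class, accelerator, topology):
    """Build a pod-spec fragment from the labels and tolerations in the tables."""
    # Node selector labels, one per row of the "Label" table.
    node_selector = {
        "cloud.google.com/compute-class": compute_class,
        "cloud.google.com/gke-tpu-accelerator": accelerator,
        "cloud.google.com/gke-tpu-topology": topology,
    }
    # Tolerations, one per row of the "Toleration Key" table.
    tolerations = [
        {"key": "google.com/tpu", "operator": "Exists", "effect": "NoSchedule"},
        {
            "key": "cloud.google.com/compute-class",
            "operator": "Exists",
            "effect": "NoSchedule",
        },
    ]
    return {"nodeSelector": node_selector, "tolerations": tolerations}


# Example values copied from the documentation tables.
fragment = tpu_scheduling_fragment(
    compute_class="tpu-multihost-v5-8",
    accelerator="tpu-v5-lite-podslice",
    topology="2x4",
)
```

The values are only the examples from the tables; the actual labels depend on the TPU node pool or ComputeClass configured in the cluster.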

content/en/docs/components/trainer/user-guides/jax.md

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ TPU workloads are not supported in the default JAX runtime because installing bo
 and `jax[tpu]` in the same image leads to backend and plugin conflicts.
 A separate TPU-specific runtime is required.
 
-Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/)
+Check out [the JAX on TPU guide](/docs/components/trainer/user-guides/jax-tpu/)
 for more details on how to run JAX on Cloud TPU.
 {{% /alert %}}
 