Commit 0a556a6
Add TPU Trillium Multi-Host RayCluster for MaxText (#1807)
* Add V6E multi-host RayCluster for MaxText
* add header
* Add Ray Train script
* update maxtext trainer script
* Add Dockerfile and region tags
* add license header and new line

Signed-off-by: Ryan O'Leary <[email protected]>
Co-authored-by: Mofi Rahman <[email protected]>
1 parent 5799e55 commit 0a556a6

File tree: 3 files changed (+224, -0 lines)
Dockerfile (44 additions, 0 deletions)

```dockerfile
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START gke_ai_ml_gke_ray_ray_train_maxtext_dockerfile]
# Start from a Ray base image, which includes the JaxTrainer API.
# MaxText with TPU requires Python 3.12.
FROM rayproject/ray:2.49.1-py312

USER root
RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray

RUN sudo apt-get update -y \
    && sudo apt-get install --no-install-recommends -y git \
    && sudo rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Clone the MaxText repo and build from source, installing TPU dependencies.
RUN git clone https://github.com/AI-Hypercomputer/maxtext.git

RUN pip install --no-cache-dir uv

RUN cd maxtext && \
    uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
    install_maxtext_github_deps

# Copy the Ray MaxText trainer to run on the remote container.
COPY maxtext_ray_trainer.py .

RUN chown -R ray:ray .
ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
USER ray
# [END gke_ai_ml_gke_ray_ray_train_maxtext_dockerfile]
```
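The image relies on the `PYTHONPATH` set above so that both the Ray Train v2 API and MaxText resolve at import time. A minimal smoke-test sketch, assuming the image has been built and tagged (the tag and script name below are placeholders, not part of this commit):

```python
# check_image.py -- hypothetical smoke test, run inside the container, e.g.:
#   docker run --rm YOUR_IMAGE_TAG python check_image.py
# Confirms that the PYTHONPATH entries baked into the Dockerfile resolve.
import importlib

for module in ("ray", "ray.train.v2.jax", "MaxText.train"):
    importlib.import_module(module)
    print(f"import ok: {module}")
```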
maxtext_ray_trainer.py (61 additions, 0 deletions)

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START gke_ai_ml_gke_ray_ray_train_maxtext_ray_trainer]
import logging
from typing import Sequence

from absl import app
import ray
from ray.train.v2.api.config import RunConfig, ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker(config):
    # Import inside the loop so MaxText is only loaded on the TPU workers.
    from MaxText.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)


def main(argv: Sequence[str]):
    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            # One Ray worker per TPU host: a v6e 4x4 slice spans 4 hosts
            # with 4 chips each, matching numOfHosts: 4 in the RayCluster.
            num_workers=4,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "env_vars": {
                    "JAX_PLATFORMS": "tpu",
                    "ENABLE_PJRT_COMPATIBILITY": "true",
                    "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                    "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                    "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                }
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()


if __name__ == "__main__":
    app.run(main)
# [END gke_ai_ml_gke_ray_ray_train_maxtext_ray_trainer]
```
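Because `main` forwards its own `argv` straight into `MaxText.train.main`, the script is launched the same way as MaxText itself: a config file followed by `key=value` overrides. A hedged launch sketch (the config path, run name, and override values below are placeholders, not values from this commit):

```python
# Hypothetical launch sketch: everything on the command line is passed
# through train_loop_config["argv"] to MaxText.train.main on each worker.
import subprocess

subprocess.run(
    [
        "python", "maxtext_ray_trainer.py",
        "MaxText/configs/base.yml",        # assumed MaxText config path
        "run_name=ray-train-demo",         # placeholder run name
        "base_output_directory=/data",     # the GCS FUSE mount from the manifest below
        "dataset_type=synthetic",          # assumed override to avoid a real dataset
        "steps=10",
    ],
    check=True,
)
```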
RayCluster manifest for a TPU v6e (Trillium) 4x4 slice (119 additions, 0 deletions)

```yaml
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START gke_ai_ml_gke_ray_ray_train_maxtext_ray_cluster_tpu_v6e_16]
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: maxtext-tpu-cluster
spec:
  headGroupSpec:
    rayStartParams: {}
    template:
      metadata:
        annotations:
          gke-gcsfuse/volumes: "true"
          gke-gcsfuse/cpu-limit: "0"
          gke-gcsfuse/memory-limit: "0"
          gke-gcsfuse/ephemeral-storage-limit: "0"
      spec:
        serviceAccountName: ${KSA_NAME}
        containers:
          - name: ray-head
            image: ${DOCKER_IMAGE}
            imagePullPolicy: IfNotPresent
            ports:
              - containerPort: 6379
                name: gcs-server
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
            resources:
              limits:
                memory: "16Gi"
              requests:
                cpu: "8"
                memory: "16Gi"
            volumeMounts:
              - name: gcs-fuse-csi-ephemeral
                mountPath: /data
              - name: dshm
                mountPath: /dev/shm
        volumes:
          - name: gcs-fuse-cache
            emptyDir:
              medium: Memory
          - name: dshm
            emptyDir:
              medium: Memory
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GS_BUCKET}
                mountOptions: "implicit-dirs"
  workerGroupSpecs:
    - replicas: 1
      numOfHosts: 4
      groupName: tpu-group
      rayStartParams: {}
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
            gke-gcsfuse/cpu-limit: "0"
            gke-gcsfuse/memory-limit: "0"
            gke-gcsfuse/ephemeral-storage-limit: "0"
        spec:
          serviceAccountName: ${KSA_NAME}
          containers:
            - name: ray-worker
              image: ${DOCKER_IMAGE}
              imagePullPolicy: IfNotPresent
              resources:
                limits:
                  memory: 200G
                  google.com/tpu: "4"
                requests:
                  cpu: "8"
                  memory: 200G
                  google.com/tpu: "4"
              env:
                - name: JAX_PLATFORMS
                  value: tpu
                - name: ENABLE_PJRT_COMPATIBILITY
                  value: "true"
              volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
          volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 4x4
# [END gke_ai_ml_gke_ray_ray_train_maxtext_ray_cluster_tpu_v6e_16]
```
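Once the cluster is applied, jobs can be submitted through the Ray dashboard port (8265) exposed on the head Pod. A sketch using Ray's job submission API, assuming the conventional KubeRay head Service name and a local port-forward (both assumptions, not part of this commit):

```python
# Hypothetical submission sketch, assuming:
#   kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265
# (KubeRay conventionally names the head Service <cluster-name>-head-svc.)
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://localhost:8265")
job_id = client.submit_job(
    # The Dockerfile copies maxtext_ray_trainer.py into /app on the image.
    entrypoint=(
        "python /app/maxtext_ray_trainer.py "
        "MaxText/configs/base.yml "        # assumed config path
        "run_name=ray-train-demo "         # placeholder run name
        "base_output_directory=/data "     # GCS FUSE mount from this manifest
        "dataset_type=synthetic steps=10"  # placeholder overrides
    ),
)
print("submitted job:", job_id)
```

From there, `client.get_job_status(job_id)` can be polled until the job reaches a terminal state, and `client.get_job_logs(job_id)` retrieves its output.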
