Skip to content

Commit 611d8ff

Browse files
author
Copybara
committed
Copybara import of gpu-recipes:
- 969ea416145c7a2486135ac16ca137fc49a3d268 Changing base container - f2f2d1a35b78252fc05ae2f6c16023c304f5c31e Updating MaxText version to fix grpc error after training... - 6232aaf0f3b073f5125c1ed29df34b81f5fa429c Merge "Clean the nemo-launcher-job helm chart" into main GitOrigin-RevId: 6232aaf0f3b073f5125c1ed29df34b81f5fa429c
1 parent d57d222 commit 611d8ff

File tree

13 files changed

+289
-139
lines changed

13 files changed

+289
-139
lines changed

src/docker/maxtext/README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
# MaxText Nightly
1+
# MaxText Benchmarks
22

33
This is the Dockerfile for building a container image for MaxText/JAX training workloads.
44
Using the following versions:
5-
- BASE_IMAGE: ghcr.io/nvidia/jax:jax-2024-12-04
6-
- JAX_VERSION: 0.4.36.dev20241202
5+
- BASE_IMAGE: ghcr.io/nvidia/jax:maxtext-2025-01-10

src/docker/maxtext/cloudbuild.yml

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,11 @@
11
steps:
2-
- name: 'gcr.io/cloud-builders/git'
3-
args: ['clone', 'https://github.com/AI-Hypercomputer/maxtext.git']
42
- name: 'docker'
5-
args: ['build',
6-
'--network', 'host',
7-
'--build-arg', 'MODE=${_MODE}',
8-
'--build-arg', 'JAX_VERSION=${_JAX_VERSION}',
9-
'--build-arg', 'DEVICE=${_DEVICE}',
10-
'--build-arg', 'BASEIMAGE=${_BASEIMAGE}',
11-
'-f', './maxtext_gpu_dependencies.Dockerfile',
12-
'-t', '${_ARTIFACT_REGISTRY}/maxtext-nightly',
13-
'.'
14-
]
15-
dir: 'maxtext' # Set the working directory to 'maxtext'
16-
options:
17-
substitution_option: 'ALLOW_LOOSE'
18-
substitutions:
19-
_MODE: 'nightly'
20-
_JAX_VERSION: '0.4.36.dev20241202'
21-
_DEVICE: 'gpu'
22-
_BASEIMAGE: 'ghcr.io/nvidia/jax:jax-2024-12-04'
3+
args:
4+
- 'build'
5+
- '--tag=${_ARTIFACT_REGISTRY}/maxtext-benchmark'
6+
- '--file=maxtext.Dockerfile'
7+
- '.'
8+
automapSubstitutions: true
9+
2310
images:
24-
- '${_ARTIFACT_REGISTRY}/maxtext-nightly'
11+
- '${_ARTIFACT_REGISTRY}/maxtext-benchmark'
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
FROM ghcr.io/nvidia/jax:maxtext-2025-01-10
16+
17+
# GCSfuse components (used to provide shared storage, not intended for high performance)
18+
RUN apt-get update && apt-get install --yes --no-install-recommends \
19+
ca-certificates \
20+
curl \
21+
gnupg \
22+
&& echo "deb https://packages.cloud.google.com/apt gcsfuse-buster main" \
23+
| tee /etc/apt/sources.list.d/gcsfuse.list \
24+
&& echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \
25+
| tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \
26+
&& curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - \
27+
&& apt-get update \
28+
&& apt-get install --yes gcsfuse \
29+
&& apt-get install --yes google-cloud-cli \
30+
&& apt-get install --yes dnsutils \
31+
&& apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
32+
&& mkdir /gcs

src/helm-charts/a3ultra/maxtext-training/templates/maxtext-launcher-job.yaml

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,32 @@ spec:
4949
gke-gcsfuse/memory-limit: "0"
5050
gke-gcsfuse/ephemeral-storage-limit: "0"
5151
{{- end}}
52+
{{- if not $root.Values.network.hostNetwork }}
53+
networking.gke.io/default-interface: "eth0"
54+
networking.gke.io/interfaces: |
55+
{{- if $root.Values.network.subnetworks }}
56+
[
57+
{{- /* One JSON object per configured subnetwork; a comma follows every entry except the last, whatever the list length. */}}{{- range $i, $subnetwork := $root.Values.network.subnetworks }}
58+
{"interfaceName":"eth{{ $i }}","network":"{{ $subnetwork }}"}{{ eq $i (sub (len $root.Values.network.subnetworks) 1) | ternary "" "," }}
59+
{{- end }}
60+
]
61+
{{- else }}
62+
[
63+
{"interfaceName":"eth0","network":"default"},
64+
{"interfaceName":"eth1","network":"{{ $root.Values.clusterName }}-sub-1"},
65+
{{- range $i := until 8 }}
66+
{"interfaceName":"eth{{ add 2 $i }}","network":"{{ $root.Values.clusterName }}-rdma-sub-{{ $i }}"}{{ eq $i 7 | ternary "" ","}}
67+
{{- end }}
68+
]
69+
{{- end }}
70+
{{- end }}
5271
spec:
5372
schedulingGates:
5473
- name: "gke.io/topology-aware-auto-scheduling"
74+
{{- if $root.Values.network.hostNetwork }}
5575
hostNetwork: true
5676
dnsPolicy: ClusterFirstWithHostNet
77+
{{- end }}
5778
subdomain: "{{.Release.Name}}"
5879
restartPolicy: Never
5980
{{ if $root.Values.targetNodes }}
@@ -84,8 +105,6 @@ spec:
84105
- name: workload-configuration
85106
configMap:
86107
name: "{{.Release.Name}}"
87-
- name: workload-terminated-volume
88-
emptyDir: {}
89108
- name: local-ssd
90109
hostPath:
91110
path: /mnt/stateful_partition/kube-ephemeral-ssd
@@ -143,8 +162,10 @@ spec:
143162
- name: workload
144163
image: "{{ $root.Values.workload.image }}"
145164
imagePullPolicy: Always
165+
{{- if $root.Values.network.hostNetwork }}
146166
securityContext:
147167
privileged: true
168+
{{- end }}
148169
env:
149170
- name: JOB_IDENTIFIER
150171
value: "{{ .Release.Name }}-{{ $timestamp }}-{{ $jobSuffix }}"
@@ -160,7 +181,7 @@ spec:
160181
- name: GCS_BUCKET
161182
value: {{ .bucketName }}
162183
{{- end }}
163-
184+
164185
# JAX-specific environment variables
165186
- name: JAX_COORDINATOR_ADDRESS
166187
value: "{{.Release.Name}}-0.{{.Release.Name}}.default.svc.cluster.local"
@@ -211,19 +232,13 @@ spec:
211232
- bash
212233
- -c
213234
- |
214-
function on_script_completion {
215-
touch /semaphore/workload_terminated
216-
}
217-
trap on_script_completion EXIT
218235
echo "Pod on $(hostname --fqdn) is running"
219236
echo "Pod is assigned job index of $JOB_COMPLETION_INDEX"
220237
echo "Job ID is $JOB_IDENTIFIER"
221238
222239
echo "Running nvidia-smi"
223240
nvidia-smi
224241
225-
226-
echo "Warning: Set LD_LIBRARY_PATH=$LD_LIBRARY_PATH to override the NCCL library"
227242
ldconfig $LD_LIBRARY_PATH
228243
echo "Added ${LD_LIBRARY_PATH} to ldconfig:"
229244
ldconfig -p | grep libcuda | sed 's/^/ /'
@@ -247,19 +262,16 @@ spec:
247262
echo "{{ . }}"
248263
{{- end }}
249264
250-
sleep 10 # <- Hack to allow some time for service to boot
251-
252265
export NODE_RANK=$JOB_COMPLETION_INDEX
253266
echo "Launching MaxText as node rank $NODE_RANK out of $NNODES nodes"
254267
255-
if [ "$NODE_RANK" -eq "1" ]; then
256-
echo "Launching nvidia-smi in daemon mode with (20 sec delay)"
257-
nvidia-smi dmon -d 20 -s pum &
258-
fi
259-
260268
echo "XLA Flags: $XLA_FLAGS"
261269
270+
sleep 10 # <- Allow some time for service to boot
271+
272+
echo "Setting JAX_COORDINATOR_IP"
262273
export JAX_COORDINATOR_IP=$(nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null | awk '/^Address: / { print $2 }' | head -n 1)
274+
echo $JAX_COORDINATOR_IP
263275
264276
# Parsing Configuration
265277
while IFS= read -r line || [[ -n "$line" ]]; \
@@ -284,8 +296,6 @@ spec:
284296
mountPath: /usr/local/nvidia
285297
- name: gib
286298
mountPath: /usr/local/gib
287-
- name: workload-terminated-volume
288-
mountPath: /semaphore
289299
- name: workload-configuration
290300
mountPath: /etc/workload-configuration
291301
- name: shared-memory

0 commit comments

Comments
 (0)