Skip to content

Commit aa80053

Browse files
committed
gpu deps
1 parent 142b565 commit aa80053

File tree

2 files changed

+50
-47
lines changed

2 files changed

+50
-47
lines changed

dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ RUN if [ "$DEVICE" = "tpu" ]; then \
5252
python3 -m pip install 'google-tunix>=0.1.2'; \
5353
fi
5454

55+
# Temporarily downgrade to JAX=0.7.2 for GPU images
56+
RUN if [ "$DEVICE" = "gpu" ]; then \
57+
python3 -m pip install -U "jax[cuda12]==0.8.1"; \
58+
python3 -m pip install -U "transformer-engine-cu12" "transformer-engine-jax" "transformer-engine"; \
59+
fi
60+
5561
# Now copy the remaining code (source files that may change frequently)
5662
COPY . .
5763

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,37 @@
44
absl-py>=2.3.1
55
aiofiles>=25.1.0
66
aiohappyeyeballs>=2.6.1
7-
aiohttp>=3.13.1
7+
aiohttp>=3.13.2
88
aiosignal>=1.4.0
9-
annotated-doc>=0.0.3
9+
annotated-doc>=0.0.4
1010
annotated-types>=0.7.0
1111
antlr4-python3-runtime>=4.9.3
1212
anyio>=4.11.0
1313
aqtp>=0.9.0
14-
array-record>=0.8.2
15-
astroid>=4.0.1
14+
array-record>=0.8.3
15+
astroid>=4.0.2
1616
astunparse>=1.6.3
1717
attrs>=25.4.0
1818
auditwheel>=6.4.2
1919
black>=24.10.0
2020
blobfile>=3.1.0
2121
build>=1.3.0
22-
cachetools>=6.2.1
22+
cachetools>=6.2.2
2323
certifi>=2025.10.5
24-
cfgv>=3.4.0
24+
cfgv>=3.5.0
2525
charset-normalizer>=3.4.4
26-
cheroot>=11.0.0
26+
cheroot>=11.1.2
2727
chex>=0.1.91
28-
click>=8.3.0
28+
click>=8.3.1
2929
cloud-accelerator-diagnostics>=0.1.1
3030
cloud-tpu-diagnostics>=0.1.5
3131
cloudpickle>=3.1.1
3232
clu>=0.0.12
3333
colorama>=0.4.6
3434
contourpy>=1.3.3
35-
coverage>=7.11.0
35+
coverage>=7.12.0
3636
cycler>=0.12.1
37-
datasets>=4.3.0
37+
datasets>=4.4.1
3838
decorator>=5.2.1
3939
dill>=0.4.0
4040
distlib>=0.4.0
@@ -46,7 +46,7 @@ einshape>=1.0
4646
etils>=1.13.0
4747
evaluate>=0.4.6
4848
execnet>=2.1.1
49-
fastapi>=0.120.1
49+
fastapi>=0.121.3
5050
filelock>=3.20.0
5151
flatbuffers>=25.9.23
5252
flax>=0.12.0
@@ -55,28 +55,27 @@ frozenlist>=1.8.0
5555
fsspec>=2025.9.0
5656
gast>=0.6.0
5757
gcsfs>=2025.9.0
58-
google-api-core>=2.28.0
59-
google-api-python-client>=2.185.0
60-
google-auth-httplib2>=0.2.0
58+
google-api-core>=2.28.1
59+
google-api-python-client>=2.187.0
60+
google-auth-httplib2>=0.2.1
6161
google-auth-oauthlib>=1.2.2
62-
google-auth>=2.41.1
63-
google-benchmark>=1.9.4
64-
google-cloud-aiplatform>=1.122.0
62+
google-auth>=2.43.0
63+
google-cloud-aiplatform>=1.128.0
6564
google-cloud-appengine-logging>=1.7.0
6665
google-cloud-audit-log>=0.4.0
6766
google-cloud-bigquery>=3.38.0
68-
google-cloud-core>=2.4.3
67+
google-cloud-core>=2.5.0
6968
google-cloud-logging>=3.12.1
7069
google-cloud-monitoring>=2.28.0
7170
google-cloud-resource-manager>=1.15.0
72-
google-cloud-storage>=2.19.0
71+
google-cloud-storage>=3.6.0
7372
google-crc32c>=1.7.1
74-
google-genai>=1.46.0
73+
google-genai>=1.52.0
7574
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
7675
google-pasta>=0.2.0
77-
google-resumable-media>=2.7.2
78-
googleapis-common-protos>=1.71.0
79-
grain>=0.2.13
76+
google-resumable-media>=2.8.0
77+
googleapis-common-protos>=1.72.0
78+
grain>=0.2.14
8079
grpc-google-iam-v1>=0.14.3
8180
grpcio-status>=1.71.2
8281
grpcio>=1.75.1
@@ -99,19 +98,18 @@ importlib-resources>=6.5.2
9998
iniconfig>=2.1.0
10099
isort>=7.0.0
101100
jaraco-functools>=4.3.0
102-
jax-cuda12-pjrt>=0.8.0 ; sys_platform == 'linux'
103-
jax-cuda12-plugin>=0.8.0 ; sys_platform == 'linux'
104-
jax-triton>=0.3.0
105-
jax>=0.8.0
106-
jaxlib>=0.8.0
101+
jax-cuda12-pjrt>=0.8.1 ; sys_platform == 'linux'
102+
jax-cuda12-plugin>=0.8.1 ; sys_platform == 'linux'
103+
jax>=0.8.1
104+
jaxlib>=0.8.1
107105
jaxtyping>=0.3.3
108106
jinja2>=3.1.6
109107
joblib>=1.5.2
110108
jsonlines>=4.0.0
111109
keras>=3.11.3
112110
kiwisolver>=1.4.9
113111
libclang>=18.1.1
114-
libcst>=1.8.5
112+
libcst>=1.8.6
115113
lxml>=6.0.2
116114
markdown-it-py>=4.0.0
117115
markdown>=3.9
@@ -128,7 +126,7 @@ mpmath>=1.3.0
128126
msgpack>=1.1.2
129127
msgspec>=0.19.0
130128
multidict>=6.7.0
131-
multiprocess>=0.70.16
129+
multiprocess>=0.70.18
132130
mypy-extensions>=1.1.0
133131
namex>=0.1.0
134132
nest-asyncio>=1.6.0
@@ -157,7 +155,7 @@ opt-einsum>=3.4.0
157155
optax>=0.2.6
158156
optree>=0.17.0
159157
optype>=0.14.0
160-
orbax-checkpoint>=0.11.26
158+
orbax-checkpoint>=0.11.28
161159
packaging>=25.0
162160
pandas>=2.3.3
163161
parameterized>=0.9.0
@@ -167,7 +165,7 @@ pillow>=12.0.0
167165
platformdirs>=4.5.0
168166
pluggy>=1.6.0
169167
portpicker>=1.6.0
170-
pre-commit>=4.3.0
168+
pre-commit>=4.4.0
171169
prometheus-client>=0.23.1
172170
promise>=2.3
173171
propcache>=0.4.1
@@ -179,14 +177,14 @@ pyasn1-modules>=0.4.2
179177
pyasn1>=0.6.1
180178
pycnite>=2024.7.31
181179
pycryptodomex>=3.23.0
182-
pydantic-core>=2.41.4
183-
pydantic>=2.12.3
180+
pydantic-core>=2.41.5
181+
pydantic>=2.12.4
184182
pydot>=4.0.1
185183
pyelftools>=0.32
186184
pyglove>=0.4.5
187185
pygments>=2.19.2
188186
pyink>=24.10.1
189-
pylint>=4.0.2
187+
pylint>=4.0.3
190188
pyparsing>=3.2.5
191189
pyproject-hooks>=1.2.0
192190
pytest-xdist>=3.8.0
@@ -195,13 +193,13 @@ python-dateutil>=2.9.0.post0
195193
pytype>=2024.10.11
196194
pytz>=2025.2
197195
pyyaml>=6.0.3
198-
qwix>=0.1.1
199-
regex>=2025.10.23
196+
qwix>=0.1.4
197+
regex>=2025.11.3
200198
requests-oauthlib>=2.0.0
201199
requests>=2.32.5
202200
rich>=14.2.0
203201
rsa>=4.9.1
204-
safetensors>=0.6.2
202+
safetensors>=0.7.0
205203
scipy-stubs>=1.16.2.4
206204
scipy>=1.16.2
207205
sentencepiece>=0.2.1
@@ -214,7 +212,7 @@ simplejson>=3.20.2
214212
six>=1.17.0
215213
sniffio>=1.3.1
216214
sortedcontainers>=2.4.0
217-
starlette>=0.48.0
215+
starlette>=0.50.0
218216
sympy>=1.14.0
219217
tabulate>=0.9.0
220218
tenacity>=9.1.2
@@ -229,32 +227,31 @@ tensorflow>=2.19.1
229227
tensorstore>=0.1.78
230228
termcolor>=3.1.0
231229
tiktoken>=0.12.0
232-
tokamax>=0.0.4
230+
tokamax>=0.0.6
233231
tokenizers>=0.22.1
234232
toml>=0.10.2
235233
tomlkit>=0.13.3
236234
toolz>=1.1.0
237235
tqdm>=4.67.1
238-
transformer-engine-cu12>=2.8.0
239-
transformer-engine-jax>=2.8.0
240-
transformer-engine>=2.8.0
236+
transformer-engine-cu12>=2.9.0
237+
transformer-engine-jax>=2.9.0
238+
transformer-engine>=2.9.0
241239
transformers>=4.57.1
242240
treescope>=0.1.10
243-
triton>=3.5.0
244241
typeguard>=2.13.3
245242
typing-extensions>=4.15.0
246243
typing-inspection>=0.4.2
247244
tzdata>=2025.2
248245
uritemplate>=4.2.0
249246
urllib3>=2.5.0
250247
uvicorn>=0.38.0
251-
virtualenv>=20.35.3
248+
virtualenv>=20.35.4
252249
wadler-lindig>=0.1.7
253250
websockets>=15.0.1
254251
werkzeug>=3.1.3
255252
wheel>=0.45.1
256-
wrapt>=2.0.0
257-
xprof>=2.20.7
253+
wrapt>=2.0.1
254+
xprof>=2.21.1
258255
xxhash>=3.6.0
259256
yarl>=1.22.0
260257
zipp>=3.23.0

0 commit comments

Comments
 (0)