Skip to content

Commit 4f3c12f

Browse files
committed
Updating Pathways-TPU integration with AxLearn
Adding a pathways-tpu extra dependency kind Adding a pathways-tpu docker image
1 parent 1c137ff commit 4f3c12f

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

Dockerfile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
9292
RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
9393
COPY . .
9494

95+
################################################################################
96+
# Pathways-TPU container spec. #
97+
################################################################################
98+
99+
FROM base AS pathways-tpu
100+
101+
ARG EXTRAS=
102+
103+
RUN uv pip install --prerelease=allow .[core,pathways-tpu] && uv cache clean
104+
RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
105+
COPY . .
106+
95107
################################################################################
96108
# GPU container spec. #
97109
################################################################################

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ gcp = [
109109
tpu = [
110110
"axlearn[gcp]",
111111
"jax[tpu]==0.5.3", # must be >=0.4.19 for compat with v5p.
112-
"pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
112+
]
113+
# For Pathways-TPU single-controller training
114+
pathways-tpu = [
115+
"axlearn[gcp]",
116+
"jax==0.5.3", # must be >=0.4.19 for compat with v5p.
117+
"pathwaysutils==0.1.1",
113118
]
114119
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
115120
vertexai_tensorboard = [

0 commit comments

Comments
 (0)