Skip to content

Commit 5489efd

Browse files
Merge pull request #2798 from AI-Hypercomputer:stable_rl
PiperOrigin-RevId: 842376410
2 parents a2d0a19 + b01cd55 commit 5489efd

File tree

5 files changed

+45
-18
lines changed

5 files changed

+45
-18
lines changed

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ RUN pip install aiohttp==3.12.15
2727

2828
RUN pip install numba==0.61.2
2929

30+
# Install Tunix
31+
RUN pip install google-tunix==0.1.5
32+
3033
# Install vLLM for Jax and TPUs
3134
RUN pip install vllm-tpu
3235

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,42 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
# Example command:
18-
# bash docker_build_dependency_image.sh MODE=stable
19-
# bash docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}}
20-
# bash docker_build_dependency_image.sh MODE=nightly
21-
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
22-
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
23-
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
24-
# MODE=custom_wheels is the same as nightly except that it reinstalls any
25-
# additional wheels that are present in the maxtext directory.
26-
# The main use case is to install custom jax or jaxlib wheels but it also
27-
# works with any custom wheels.
28-
# bash docker_build_dependency_image.sh MODE=custom_wheels
29-
30-
# bash docker_build_dependency_image.sh MODE=post-training
31-
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
17+
# This script is used to build the MaxText Docker image, supporting
18+
# different environments (stable, nightly) and use cases (pre-training, post-training).
19+
# IMPORTANT: This script must be executed from the root directory of the MaxText repository.
20+
21+
# ==================================
22+
# PRE-TRAINING BUILD EXAMPLES
23+
# ==================================
24+
25+
# Build docker image with stable dependencies
26+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=stable
27+
## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}}
28+
29+
# Build docker image with nightly dependencies
30+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=nightly
31+
32+
# Build docker image with stable dependencies and, a pinned JAX_VERSION for TPUs
33+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
34+
35+
# Build docker image with stable dependencies and, a pinned JAX_VERSION for GPUs
36+
# Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax
37+
## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109
38+
39+
# MODE=custom_wheels builds the nightly environment, then reinstalls any
40+
# additional wheels present in the maxtext directory.
41+
# Use this mode to install custom dependencies, such as custom JAX or JAXlib builds.
42+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=custom_wheels
43+
44+
# ==================================
45+
# POST-TRAINING BUILD EXAMPLES
46+
# ==================================
47+
48+
# Build docker image with stable pre-training dependencies and stable post-training dependencies
49+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
50+
51+
# Build docker image with stable pre-training dependencies and post-training dependencies from GitHub head
52+
## bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
3253

3354
if [ "${BASH_SOURCE-}" ]; then
3455
this_file="${BASH_SOURCE[0]}"

docs/install_maxtext.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Run the following command, replacing `<jax-build-commit-hash>` with the hash you
9999

100100
```bash
101101
seed-env \
102-
--local-requirements=base_requirements/tpu-base-requirements.txt \
102+
--local-requirements=dependencies/requirements/base_requirements/tpu-base-requirements.txt \
103103
--host-name=MaxText \
104104
--seed-commit=<jax-build-commit-hash> \
105105
--python-version=3.12 \
@@ -113,7 +113,7 @@ Similarly, run the command for the GPU requirements.
113113

114114
```bash
115115
seed-env \
116-
--local-requirements=base_requirements/cuda12-base-requirements.txt \
116+
--local-requirements=dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
117117
--host-name=MaxText \
118118
--seed-commit=<jax-build-commit-hash> \
119119
--python-version=3.12 \

docs/tutorials/posttraining/sft.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ source $VENV_NAME/bin/activate
3737

3838
# 3. Install dependencies in editable mode
3939
uv pip install -e .[tpu] --resolution=lowest
40-
install_maxtext_github_deps
40+
bash tools/setup/setup_post_training_requirements.sh
4141
```
4242

4343
## Setup environment variables

tools/setup/setup_post_training_requirements.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ uv pip uninstall jax jaxlib libtpu
2323

2424
uv pip install aiohttp==3.12.15
2525

26+
# Install Tunix
27+
uv pip install google-tunix==0.1.5
28+
2629
# Install vLLM for Jax and TPUs
2730
uv pip install vllm-tpu
2831

0 commit comments

Comments
 (0)