
Commit e4067bf

Merge branch 'main' into docs-reorg

2 parents bef18ba + a8499dd


59 files changed: 4,129 additions and 807 deletions.

.github/workflows/RunTests.yml

Lines changed: 6 additions & 1 deletion
@@ -65,17 +65,22 @@ jobs:
 
   cpu_unit_tests:
     needs: tpu_image
+    strategy:
+      fail-fast: false
+      matrix:
+        worker_group: [1, 2, 3, 4]
     uses: ./.github/workflows/run_tests_internal.yml
     with:
       device_type: cpu
       device_name: X64
-      cloud_runner: linux-x86-n2-16
       image_type: tpu
       pytest_marker: 'cpu_only'
       xla_python_client_mem_fraction: 0.75
       tf_force_gpu_allow_growth: false
       container_resource_option: "--privileged"
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 4
 
   tpu_unit_tests:
     needs: tpu_image
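The new `strategy.matrix` fans `cpu_unit_tests` out into four parallel jobs; each passes its `worker_group` (plus the fixed `total_workers: 4`) down to the reusable workflow, which turns them into pytest-split arguments. Roughly what one shard executes, as a local sketch (the `cpu_only` marker comes from this workflow; the shard numbers here are illustrative):

```bash
# Run shard 3 of 4 of the cpu_only suite with pytest-split:
# --splits = total shard count, --group = this job's 1-based shard id.
python3 -m pip install --quiet pytest-split
python3 -m pytest -v -m "cpu_only" --durations=0 --splits 4 --group 3
```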

.github/workflows/build_and_test_maxtext.yml

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,7 @@ jobs:
       fail-fast: false # don't cancel all jobs on failure
       matrix:
         image_type: ["py312"]
+        worker_group: [1, 2, 3, 4]
     with:
       device_type: cpu
       device_name: X64
@@ -61,6 +62,8 @@
       tf_force_gpu_allow_growth: false
       container_resource_option: "--privileged"
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      worker_group: ${{ matrix.worker_group }}
+      total_workers: 4
 
   maxtext_tpu_unit_tests:
     needs: build_and_upload_maxtext_package

.github/workflows/run_tests_against_package.yml

Lines changed: 15 additions & 6 deletions
@@ -31,10 +31,6 @@ on:
       pytest_marker:
         required: true
         type: string
-      pytest_addopts:
-        required: false
-        type: string
-        default: ''
       is_scheduled_run:
         required: true
         type: string
@@ -50,12 +46,20 @@ on:
       cloud_runner:
         required: false
         type: string
+      worker_group:
+        required: false
+        type: number
+        default: 1
+      total_workers:
+        required: false
+        type: number
+        default: 1
 
 permissions:
   contents: read
 jobs:
   run:
-    runs-on: ${{ inputs.cloud_runner }}
+    runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }}
     container:
       image: gcr.io/tpu-prod-env-multipod/maxtext-unit-test-${{ inputs.device_type == 'cpu' && 'tpu' || inputs.device_type }}:${{ inputs.image_type != '' && inputs.image_type }}
       env:
@@ -97,5 +101,10 @@ jobs:
         export MAXTEXT_ASSETS_ROOT=$(pwd)/src/MaxText/assets
         export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/src/MaxText/test_assets
         export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText
+        # omit this libtpu init args for gpu tests
+        if [ "${{ inputs.device_type }}" != "cuda12" ]; then
+          export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
+        fi
         # TODO: Fix the skipped tests and remove the deselect flags
-        .venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize"
+        [ "${{ inputs.total_workers }}" -gt 1 ] && .venv/bin/python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
+        .venv/bin/python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" --deselect "tests/tokenizer_test.py::TokenizerTest::test_detokenize" $SPLIT_ARGS
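One reviewer-level caveat on the added `SPLIT_ARGS` one-liner: in shell, `A && B || C` is not a true if/else. If the `pip install` of pytest-split fails (say, a transient network error), the `||` branch still fires and silently clears `SPLIT_ARGS`, so the job would run the full unsplit suite. A more defensive sketch, with hypothetical stand-ins for the workflow inputs:

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical values standing in for inputs.total_workers / inputs.worker_group.
TOTAL_WORKERS=4
WORKER_GROUP=2

SPLIT_ARGS=""
if [ "$TOTAL_WORKERS" -gt 1 ]; then
  # With set -e, a failed install aborts the job instead of
  # silently falling back to an unsplit run.
  python3 -m pip install --quiet pytest-split
  SPLIT_ARGS="--splits $TOTAL_WORKERS --group $WORKER_GROUP"
fi
python3 -m pytest -v --durations=0 $SPLIT_ARGS
```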

.github/workflows/run_tests_internal.yml

Lines changed: 12 additions & 2 deletions
@@ -50,6 +50,14 @@ on:
       cloud_runner:
         required: false
         type: string
+      worker_group:
+        required: false
+        type: number
+        default: 1
+      total_workers:
+        required: false
+        type: number
+        default: 1
 
 jobs:
   run:
@@ -70,5 +78,7 @@ jobs:
         else
           FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
         fi
-        python3 -m pip install -e . --no-dependencies &&
-        LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0
+        python3 -m pip install -e . --no-dependencies
+        [ "${{ inputs.total_workers }}" -gt 1 ] && python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
+        export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
+        python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 $SPLIT_ARGS
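Note the behavioral nuance in moving `LIBTPU_INIT_ARGS` from a per-command prefix to an `export`: a prefix assignment scopes the variable to that single process, while `export` makes it visible to every later command in the step. A minimal demonstration:

```bash
# Per-command assignment: visible only to this one process.
LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' \
  python3 -c 'import os; print(os.environ.get("LIBTPU_INIT_ARGS"))'
echo "${LIBTPU_INIT_ARGS:-unset}"   # prints "unset": the shell itself never had it

# Exported: every subsequent command in the step inherits it.
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
python3 -c 'import os; print(os.environ.get("LIBTPU_INIT_ARGS"))'
```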

README.md

Lines changed: 11 additions & 3 deletions
@@ -36,10 +36,14 @@ We recommend installing MaxText inside a Python virtual environment.
 This is the easiest way to get started with the latest stable version.
 
 ```bash
-# 1. Install uv, a fast Python package installer
+# 1. Create virtual environment
+uv venv --python 3.12 --seed maxtext_venv
+source maxtext_venv/bin/activate
+
+# 2. Install uv, a fast Python package installer
 pip install uv
 
-# 2. Install MaxText and its dependencies
+# 3. Install MaxText and its dependencies
 uv pip install maxtext --resolution=lowest
 install_maxtext_github_deps
 ```
@@ -55,7 +59,11 @@ If you plan to contribute to MaxText or need the latest unreleased features, ins
 git clone https://github.com/AI-Hypercomputer/maxtext.git
 cd maxtext
 
-# 2. Install dependencies in editable mode
+# 2. Create virtual environment
+uv venv --python 3.12 --seed maxtext_venv
+source maxtext_venv/bin/activate
+
+# 3. Install dependencies in editable mode
 pip install uv
 # install the tpu package
 uv pip install -e .[tpu] --resolution=lowest
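A quick sanity check after either set of venv steps, assuming the same `maxtext_venv` name used above:

```bash
source maxtext_venv/bin/activate
which python3       # should resolve inside maxtext_venv/bin/
python3 --version   # expect Python 3.12.x, matching --python 3.12
```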

base_requirements/requirements.txt

Lines changed: 4 additions & 3 deletions
@@ -11,8 +11,8 @@ google-cloud-aiplatform
 google-cloud-monitoring
 grain[parquet]
 huggingface_hub
-jax!=0.7.1, !=0.7.2
-jaxlib!=0.7.1, !=0.7.2
+jax
+jaxlib
 jaxtyping
 jsonlines
 ml-collections
@@ -36,7 +36,8 @@ tensorflow-datasets
 tensorflow-text
 tensorflow
 tiktoken
+tokamax
 transformers
 qwix
-google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
+google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
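With the `!=0.7.1, !=0.7.2` exclusions dropped, the resolver is free to pick any compatible jax/jaxlib pair, so the versions actually installed can drift between builds. A quick way to see what was resolved (a sketch, not part of this commit):

```bash
# Show the jax/jaxlib versions the resolver actually chose.
python3 -m pip show jax jaxlib | grep -E '^(Name|Version)'
python3 -c "import jax; print(jax.__version__)"
```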

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 119 additions & 0 deletions
@@ -1714,6 +1714,125 @@
     ),
 )
 
+gemma3_12b_32768_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
+gemma3_12b_32768_2x_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-2x-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": 1,
+            "ici_fsdp_transpose_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
+gemma3_12b_32768_4x_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-4x-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": 1,
+            "ici_fsdp_transpose_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
 # Config for Llama3.1 70B model with 131072 max target length aka context length
 llama3_1_70b_131072 = _add_to_model_dictionary(
     trillium_model_dict,
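For scale context on these three configs: assuming the v6e-256 slice implied by the names (256 chips) and the `per_device_batch_size: 1` above, the per-step token count works out as follows (a back-of-the-envelope sketch, not code from this commit):

```bash
# 1 sequence/chip × 256 chips × 32768 tokens/sequence
echo $(( 1 * 256 * 32768 ))   # 8388608 tokens per training step
```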

docker_build_dependency_image.sh

Lines changed: 11 additions & 11 deletions
@@ -27,7 +27,7 @@
 # works with any custom wheels.
 # bash docker_build_dependency_image.sh MODE=custom_wheels
 
-# bash docker_build_dependency_image.sh MODE=grpo
+# bash docker_build_dependency_image.sh MODE=post-training
 
 # Enable "exit immediately if any command fails" option
 set -e
@@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then
   export MODE=stable
   echo "Default MODE=${MODE}"
   export CUSTOM_JAX=0
-  export INSTALL_GRPO=0
+  export INSTALL_POST_TRAINING=0
 elif [[ ${MODE} == "custom_wheels" ]] ; then
   export MODE=nightly
   export CUSTOM_JAX=1
-  export INSTALL_GRPO=0
-elif [[ ${MODE} == "grpo" || ${MODE} == "grpo-experimental" ]] ; then
-  export INSTALL_GRPO=1
+  export INSTALL_POST_TRAINING=0
+elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then
+  export INSTALL_POST_TRAINING=1
   export CUSTOM_JAX=0
 else
   export CUSTOM_JAX=0
-  export INSTALL_GRPO=0
+  export INSTALL_POST_TRAINING=0
 fi
 
 if [[ -z ${DEVICE} ]]; then
@@ -124,8 +124,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
 elif [[ ${MANTARAY} == "true" ]]; then
   echo "Building with benchmark-db"
   docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
-elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then
-  echo "Installing MaxText stable mode dependencies for GRPO"
+elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then
+  echo "Installing MaxText stable mode dependencies for Post-Training"
   docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
 else
   docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
@@ -136,9 +136,9 @@ else
   docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
 fi
 
-if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
+if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
   if [[ ${DEVICE} != "tpu" ]] ; then
-    echo "Error: MODE=grpo is only supported for DEVICE=tpu"
+    echo "Error: MODE=post-training is only supported for DEVICE=tpu"
     exit 1
   fi
 
@@ -158,7 +158,7 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
     --network host \
     --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
     --build-arg MODE=${MODE} \
-    -f ./maxtext_grpo_dependencies.Dockerfile \
+    -f ./maxtext_post_training_dependencies.Dockerfile \
     -t ${LOCAL_IMAGE_NAME} .
 fi
 
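Under the rename, the invocation from the script's header comment becomes the following; per the guard above, the build errors out unless `DEVICE` resolves to `tpu` (the explicit `DEVICE=tpu` here is illustrative, since the script also applies a default):

```bash
# Build the dependency image with post-training extras (formerly MODE=grpo).
bash docker_build_dependency_image.sh MODE=post-training DEVICE=tpu
```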
docs/guides.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ guides/pallas_kernels_performance.md
3131
guides/understand_logs_and_metrics.md
3232
guides/xprof_user_guide.md
3333
guides/checkpointing_solutions.md
34+
guides/megascale_hang_playbook.md
3435
```
