
Commit 8d9a1d6

Merge pull request #2402 from melissawm:docs-reorg
PiperOrigin-RevId: 825605905
2 parents 3d51c99 + 05a95b7 · commit 8d9a1d6


47 files changed: +401 −496 lines

docs/conf.py

Lines changed: 4 additions & 3 deletions
````diff
@@ -36,6 +36,7 @@
 extensions = [
     "myst_nb",
     "sphinx_design",
+    "sphinx_copybutton",
 ]

 templates_path = ["_templates"]
@@ -59,7 +60,7 @@

 # Remove specific documents from ToC
 exclude_patterns = [
-    "guides/run_maxtext_via_multihost_job.md",
-    "guides/run_maxtext_via_multihost_runner.md",
-    "guides/llm_calculator.ipynb",
+    "guides/run_maxtext/run_maxtext_via_multihost_job.md",
+    "guides/run_maxtext/run_maxtext_via_multihost_runner.md",
+    "explanations/llm_calculator.ipynb",
 ]
````

docs/explanations.md

Lines changed: 4 additions & 2 deletions
````diff
@@ -19,9 +19,11 @@
 ```{toctree}
 :maxdepth: 1

-explanations/steps_model.md
+explanations/jax_ai_libraries_chosen.md
+explanations/alternatives.md
+explanations/checkpoints.md
 explanations/quantization.md
 explanations/sharding.md
-explanations/data_pipeline_perf.md
 explanations/tiling.md
+explanations/performance_metrics.md
 ```
````
Lines changed: 1 addition & 1 deletion
````diff
@@ -14,7 +14,7 @@
 limitations under the License.
 -->

-# Comparison to Alternatives
+# Comparison to alternatives

 MaxText is heavily inspired by [MinGPT](https://github.com/karpathy/minGPT)/[NanoGPT](https://github.com/karpathy/nanoGPT), elegant standalone GPT implementations written in PyTorch and targeting Nvidia GPUs. MaxText is more complex, supporting more industry standard models and scaling to tens of thousands of chips. Ultimately MaxText has an MFU more than three times the [17%](https://twitter.com/karpathy/status/1613250489097027584?cxt=HHwWgIDUhbixteMsAAAA) reported most recently with that codebase, is massively scalable and implements a key-value cache for efficient auto-regressive decoding.

````
Lines changed: 8 additions & 6 deletions
````diff
@@ -16,7 +16,7 @@

 # Checkpoints

-## Checkpoint Formats
+## Checkpoint formats

 Checkpoint formats in MaxText can be categorized along two axes: whether they include **training states** (e.g., optimizer properties) and whether the model's parameter weights are **stacked** or **unstacked** (aka scanned/unscanned). This results in the four types summarized below:

@@ -27,13 +27,13 @@ Checkpoint formats in MaxText can be categorized along two axes: whether they in

 We discuss these two axes respectively:

-### Training States
+### Training states

 Checkpoints with a **training state** contain more than just the model's parameter weights. They also include the **optimizer state** (e.g., momentum values), which is essential for resuming a training run exactly where it left off. These "training checkpoints" are typically saved as snapshots during training to allow for recovery if the process is interrupted.

 In contrast, **inference checkpoints** contain only the parameter weights. We also call them parameter only/param-only checkpoints. This is the format most commonly used for sharing models on public platforms like HuggingFace, as they are smaller and ready for immediate use in inference or for fine-tuning.

-### Stacked Checkpoints and JAX Scan Function
+### Stacked checkpoints and JAX scan function

 The concept of stacked vs. unstacked checkpoints is specific to JAX-based models that use the `jax.lax.scan` function ([doc](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). `scan` is a powerful JAX feature that compiles sequential operations (like the layers of a Transformer) into a single, highly optimized kernel, avoiding the overhead of a Python for-loop.

@@ -78,11 +78,11 @@ In MaxText, we treat **Stacked Inference Checkpoints** as the default format for

 ---

-## Using Checkpoints in Practice
+## Using checkpoints in practice

 Beyond understanding the formats, it's crucial to know how to use checkpoints in your training workflows. MaxText uses flags in the configuration file or on the command line to manage checkpoints.

-### Saving Checkpoints During Training
+### Saving checkpoints during training

 MaxText automatically saves checkpoints periodically during a training run. These are **Stacked Training Checkpoints** that contain the full state needed to resume.

@@ -97,4 +97,6 @@ Furthermore, MaxText supports emergency checkpointing, which saves a local copy
 - `local_checkpoint_directory`: The local path for storing emergency checkpoints.
 - `local_checkpoint_period`: The interval, in training steps, for saving local checkpoints.

-More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/518a87037abb2497a2514ff0c8ffc263c69c6f9f/MaxText/configs/base.yml#L23-L65).
+More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/base.yml#L23-L65).
+
+For practical guides on checkpointing, please refer to [](../guides/checkpointing_solutions.md).
````
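
The stacked-vs-unstacked distinction in the checkpoints page above follows directly from how `jax.lax.scan` consumes parameters. Below is a minimal sketch, illustrative only and not MaxText's actual layer code, of a scanned stack of layers whose weights live in one array with a leading `num_layers` axis:

```python
import jax
import jax.numpy as jnp

num_layers, d_model = 4, 8

# "Stacked" weights: one (num_layers, d_model, d_model) array rather than a
# separate (d_model, d_model) array per layer.
stacked_w = jax.random.normal(jax.random.PRNGKey(0), (num_layers, d_model, d_model))

def layer(x, w):
  # Stand-in for one transformer layer; scan feeds it one slice of stacked_w.
  return jnp.tanh(x @ w), None

def forward(x, stacked_w):
  # scan iterates over the leading axis of stacked_w in a single compiled loop,
  # which is why scanned models naturally save their weights in stacked form.
  x, _ = jax.lax.scan(layer, x, stacked_w)
  return x

x = jnp.ones((d_model,))
print(forward(x, stacked_w).shape)  # (8,)

# An unstacked (unscanned) checkpoint would instead store per-layer arrays
# w_0, ..., w_3; stacking them back is just jnp.stack([w_0, ..., w_3]).
```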

docs/guides/jax_ai_libraries_chosen.md renamed to docs/explanations/jax_ai_libraries_chosen.md

Lines changed: 8 additions & 15 deletions
````diff
@@ -1,4 +1,4 @@
-# The JAX Ecosystem in MaxText: An Opinionated Guide
+# The JAX ecosystem in MaxText: an opinionated guide

 MaxText is built on a curated stack of JAX libraries, each chosen for a specific purpose. This document provides an opinionated view on *why* MaxText uses the following key components of the JAX ecosystem:

@@ -11,11 +11,9 @@ MaxText is built on a curated stack of JAX libraries, each chosen for a specific

 This stack isn't just a random collection of tools; it represents a design philosophy centered around **explicitness, composability, and performance at scale**.

-
 This document provides an opinionated view on *why* MaxText uses these specific libraries, explaining the design decisions that make them ideal for building and training large-scale models.

-
-## Flax: For Functional Model Definition
+## Flax: For functional model definition

 **What is it?** Flax is a high-performance neural network library for JAX that is designed to be flexible, explicit, and easy to use.

@@ -27,8 +25,7 @@ With its latest generation API, NNX, Flax provides a modern, object-oriented (OO

 For more information on using Flax, please refer to https://github.com/google/flax

-
-## Optax: For Composable Optimization
+## Optax: For composable optimization

 **What is it?** Optax is a gradient processing and optimization library for JAX. It reimagines the optimizer as a series of composable functional transformations.

@@ -38,8 +35,7 @@ For more information on using Flax, please refer to https://github.com/google/fl

 For more information on using Optax, please refer to https://github.com/google-deepmind/optax

-
-## Orbax: For Robust Checkpointing
+## Orbax: For robust checkpointing

 **What is it?** Orbax is a library for checkpointing JAX programs, designed for large-scale, potentially unreliable environments.

@@ -54,8 +50,7 @@ For massive models, saving and loading state is a critical part of the training

 For more information on using Orbax, please refer to https://github.com/google/orbax

-
-## Grain: For Deterministic, Multi-Host Data Loading
+## Grain: For deterministic, multi-host data loading

 **What is it?** Grain is a high-performance data loading library designed for deterministic, global shuffle and multi-host data loading.

@@ -67,8 +62,7 @@ Its APIs are explicitly designed for the multi-host paradigm, simplifying the pr

 For more information on using Grain, please refer to https://github.com/google/grain and the grain guide in maxtext located at https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_grain.md

-
-## Qwix: For Native JAX Quantization
+## Qwix: For native JAX quantization

 **What is it?** Qwix is a Jax quantization library supporting Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ)

@@ -79,8 +73,7 @@ We chose Qwix because it provides the necessary primitives **natively within the

 For more information on how to quantize your model using Qwix, please refer to https://github.com/google/qwix

-
-## Tunix: For Comprehensive Post-Training
+## Tunix: For comprehensive post-training

 **What is it?** Tunix is a JAX-based library designed for a wide range of post-training tasks, including Supervised Fine-Tuning (SFT), Reinforcement Learning (RL), and Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA.

@@ -95,4 +88,4 @@ MaxText leverages Tunix as its core library for post-training, offering a unifie

 We chose Tunix because it provides a **comprehensive, performant, and JAX-native solution for the entire post-training lifecycle**. Its integration with libraries like vLLM and its alignment with the NNX ecosystem make it a powerful tool for both full model adaptation and parameter-efficient tuning.

-For more information on using Tunix, please refer to https://github.com/google/tunix
+For more information on using Tunix, please refer to https://github.com/google/tunix
````
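
The Optax entry above describes the optimizer as a series of composable gradient transformations. A minimal sketch of that composability, using standard Optax APIs rather than MaxText's actual optimizer setup:

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

# Each element is a GradientTransformation; optax.chain composes them in order.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # clip gradients first
    optax.adamw(learning_rate=1e-3),  # then apply the AdamW update rule
)
opt_state = optimizer.init(params)

def loss_fn(p):
  return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```
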
Lines changed: 8 additions & 7 deletions
````diff
@@ -14,7 +14,8 @@
 limitations under the License.
 -->

-# Performance Metrics
+(performance-metrics)=
+# Performance metrics

 ## MFU

@@ -57,11 +58,11 @@

 Hence, MFU is the fraction of peak hardware performance actually utilized by the model, and can be expressed in different units — step time, throughput, or raw flops/s.

-### MaxText Calculating + Reporting
-In MaxText, we sum all of the matmuls performed in one step, see [calculate_tflops_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/e969faabbb571285a51545530f34d8f0a9f237e9/MaxText/maxtext_utils.py#L297)
-and divide it by the measured (via python `time.time()`) step time. In each step we print the resulting Model Flops per second [`per_device_tflops_per_sec`](https://github.com/AI-Hypercomputer/maxtext/blob/e969faabbb571285a51545530f34d8f0a9f237e9/MaxText/metric_logger.py#L193-L194). One can calculate the MFU by dividing this number by the peak tflops of the hardware (e.g., $918e^{12}$ FLOPS/s for Trillium).
+### MaxText calculating + reporting
+In MaxText, we sum all of the matmuls performed in one step, see [calculate_tflops_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/maxtext_utils.py#L454)
+and divide it by the measured (via python `time.time()`) step time. In each step we print the resulting Model Flops per second [`per_device_tflops_per_sec`](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/metric_logger.py#L211-L213). One can calculate the MFU by dividing this number by the peak tflops of the hardware (e.g., $918e^{12}$ FLOPS/s for Trillium).

-### Causal Attention
+### Causal attention
 Due to causality only half of the (query, key) pairs need to be computed, those with query_idx >= key_idx. This accounts for the fact only prior tokens can be used to predict future ones. Prior to https://github.com/AI-Hypercomputer/maxtext/pull/1988 MaxText did not account for sparsity for theoretical flops, and used

 Attention Flops ~= 4 * sequence^2 * batch * heads * head_dim
@@ -98,6 +99,6 @@ $$\begin{align*}

 This shows any of step time, tokens/s or MFU can be used to determine how long training will take and are proportionally (or inversely proportionally) related. MFU is most useful to compare across different models/hardwares and while optimizing performance, whereas step time or tokens/second may be more useful when these are fixed.

-## Why not Hardware Flops?
+## Why not hardware flops?

-Hardware (e.g., XLA reported) FLOPs do not accurately reflect computation efficiency as they depend on the program / implementation, not just on the model and its inherent computations (higher hardware FLOPs does not necessarily mean less room for improvement). For example, they include remat and potentially auxiliary operations (such as reshaping for dropping moe [here](https://github.com/AI-Hypercomputer/maxtext/blob/4b6142950aff5d9ba42d830efc5ce4c4ac9d4135/MaxText/layers/moe.py#L1267)), which are an implementation detail and not part of the model. In addition, XLA reported FLOPs may not be accurate with pallas kernels. Hardware flops utilization is not (inversely) proportional to step time as opposed to MFU, since hardware flops can change with implementation details like remat policies.
+Hardware (e.g., XLA reported) FLOPs do not accurately reflect computation efficiency as they depend on the program / implementation, not just on the model and its inherent computations (higher hardware FLOPs does not necessarily mean less room for improvement). For example, they include remat and potentially auxiliary operations (such as reshaping for dropping moe [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/layers/moe.py#L1544)), which are an implementation detail and not part of the model. In addition, XLA reported FLOPs may not be accurate with pallas kernels. Hardware flops utilization is not (inversely) proportional to step time as opposed to MFU, since hardware flops can change with implementation details like remat policies.
````
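
The MFU bookkeeping described in the performance-metrics diff above reduces to a short calculation. A minimal sketch with illustrative numbers (MaxText derives the per-device model TFLOPs analytically in `calculate_tflops_per_device` and measures step time with `time.time()`; the step FLOPs and step time below are assumed values, not measurements):

```python
PEAK_TFLOPS_PER_DEVICE = 918.0      # e.g., Trillium peak, in TFLOP/s

per_device_tflops_per_step = 250.0  # analytic model FLOPs for one step (illustrative)
measured_step_time_s = 0.5          # wall-clock step time (illustrative)

# Reported each step as per_device_tflops_per_sec.
per_device_tflops_per_sec = per_device_tflops_per_step / measured_step_time_s

# MFU = achieved model FLOP/s divided by the hardware's peak FLOP/s.
mfu = per_device_tflops_per_sec / PEAK_TFLOPS_PER_DEVICE
print(f"MFU = {mfu:.1%}")           # 500 / 918 ≈ 54.5%
```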

docs/explanations/steps_model.md

Lines changed: 0 additions & 33 deletions
This file was deleted.

docs/guides.md

Lines changed: 2 additions & 10 deletions
````diff
@@ -19,26 +19,18 @@
 ```{toctree}
 :maxdepth: 1

-guides/checkpoints.md
+guides/run_maxtext.md
 guides/custom_model.md
-guides/run_maxtext_localhost.md
-guides/run_maxtext_via_xpk.md
-guides/run_maxtext_via_pathways.md
 guides/data_input_pipeline.md
-guides/single_host_gpu.md
 guides/knowledge_distillation.md
 guides/gcp_workload_observability.md
 guides/monitor_goodput.md
 guides/use_vertex_ai_tensorboard.md
 guides/features_and_diagnostics.md
 guides/pallas_kernels_performance.md
-guides/performance_metrics.md
 guides/understand_logs_and_metrics.md
-guides/checkpointing_solutions/gcs_checkpointing.md
-guides/checkpointing_solutions/emergency_checkpointing.md
-guides/checkpointing_solutions/multi_tier_checkpointing.md
-guides/jax_ai_libraries_chosen.md
 guides/xprof_user_guide.md
+guides/checkpointing_solutions.md
 guides/megascale_hang_playbook.md
 guides/multimodal.md
 ```
````
Lines changed: 9 additions & 0 deletions
````diff
@@ -0,0 +1,9 @@
+# Checkpointing solutions
+
+```{toctree}
+:maxdepth: 1
+
+checkpointing_solutions/gcs_checkpointing.md
+checkpointing_solutions/emergency_checkpointing.md
+checkpointing_solutions/multi_tier_checkpointing.md
+```
````
