
Commit 6f47311

Merge pull request #2754 from AI-Hypercomputer:custom_model_doc
PiperOrigin-RevId: 840576308
2 parents cf75890 + 430bb7e commit 6f47311

File tree

1 file changed: +116, -24 lines

docs/guides/optimization/custom_model.md

Lines changed: 116 additions & 24 deletions
@@ -24,13 +24,15 @@ This document provides a guide to optimize and customize your LLM model configur

To begin, identify your model's size, review open-source model configs, and establish the initial configurations for each block. You can use our [reference calculator (on Colab)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/docs/explanations/llm_calculator.ipynb) to estimate the parameter count and FLOPs for dense, Mixtral-like Mixture of Experts (MoE), and DeepSeek-like MoE models.

-Based on resources like [Language Modeling from Scratch](https://github.com/stanford-cs336/spring2025-lectures/blob/e9cb2488fdb53ea37f0e38924ec3a1701925cef3/nonexecutable/2025%20Lecture%203%20-%20architecture.pdf), we observe common architectural ratios for dense models, as shown below:
+Based on resources like [Language Modeling from Scratch](https://github.com/stanford-cs336/spring2025-lectures/blob/e9cb2488fdb53ea37f0e38924ec3a1701925cef3/nonexecutable/2025%20Lecture%203%20-%20architecture.pdf), common architectural ratios include:
+
+Dense models

* `mlp_dim / emb_dim`: 2.5-4
* `head_dim * num_query_heads / emb_dim`: 1-2
* `emb_dim / num_decoder_layers`: 100-200

-For MoE models,
+MoE models

* sparsity (`num_experts / num_experts_per_tok`): 4-32
* `moe_mlp_dim / emb_dim`: 0.3-3
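
These ratio ranges are easy to sanity-check for a candidate architecture. A minimal illustrative sketch in plain Python (the function and argument names are ours, not MaxText config keys):

```python
# Check a candidate architecture against the rough ratio ranges listed above.
# Argument names mirror the ratios in the text; they are illustrative only.

def check_dense_ratios(emb_dim, mlp_dim, head_dim, num_query_heads, num_decoder_layers):
    return {
        "mlp_dim / emb_dim in [2.5, 4]": 2.5 <= mlp_dim / emb_dim <= 4,
        "head_dim * num_query_heads / emb_dim in [1, 2]":
            1 <= head_dim * num_query_heads / emb_dim <= 2,
        "emb_dim / num_decoder_layers in [100, 200]":
            100 <= emb_dim / num_decoder_layers <= 200,
    }

def check_moe_ratios(num_experts, num_experts_per_tok, moe_mlp_dim, emb_dim):
    sparsity = num_experts / num_experts_per_tok
    return {
        "sparsity in [4, 32]": 4 <= sparsity <= 32,
        "moe_mlp_dim / emb_dim in [0.3, 3]": 0.3 <= moe_mlp_dim / emb_dim <= 3,
    }

# Example: a roughly Llama-3-8B-shaped dense model passes all three checks.
print(check_dense_ratios(emb_dim=4096, mlp_dim=14336, head_dim=128,
                         num_query_heads=32, num_decoder_layers=32))
```
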
@@ -39,17 +41,32 @@ For MoE models,

### Model configs

-To unlock peak performance on [TPUs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm), it is critical to keep the Matrix Multiply Unit (MXU) fully utilized. The MXU is the primary computational engine, with the Trillium chip specifically optimized for 256×256 matrix multiplications (earlier TPU versions, like v4/v5e/v5p, are optimized for 128×128 operations). Processing smaller matrix multiplications (e.g., two 128×128 operations on Trillium) will halve the efficiency compared to a single, fully-utilized 256×256 operation.
+To unlock peak performance on [TPUs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm), it is critical to keep the Matrix Multiply Unit (MXU) fully utilized. The MXU is the primary computational engine, with the [Trillium](https://docs.cloud.google.com/tpu/docs/v6e) and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) chips specifically optimized for 256×256 matrix multiplications (earlier TPU versions, like [v4](https://docs.cloud.google.com/tpu/docs/v4)/[v5e](https://docs.cloud.google.com/tpu/docs/v5e)/[v5p](https://docs.cloud.google.com/tpu/docs/v5p), are optimized for 128×128 operations). Processing smaller matrix multiplications (e.g., two 128×128 operations on Trillium and Ironwood) will halve the efficiency compared to a single, fully-utilized 256×256 operation.

Therefore, for optimal efficiency:

-* Model and MLP Dimensions: Design your model's emb_dim and mlp_dim to be multiples of 256 (for Trillium) or 128 (for older TPUs).
-* Self-Attention Head Dimension: Ensure your attention head_dim are also multiples of 256 (for Trillium) or 128 (for older TPUs).
+* Model and MLP Dimensions: Design your model's `emb_dim` and `mlp_dim` to be multiples of 256 (for Trillium and Ironwood) or 128 (for older TPUs).
+* Self-Attention Head Dimension: Ensure your attention `head_dim` is also a multiple of 256 (for Trillium and Ironwood) or 128 (for older TPUs).

Generally, larger multiples are more efficient. If achieving these specific multiples isn't possible, prioritize dimensions to a multiple of either 8 or 128 to help the XLA compiler optimize memory and computation.

To achieve efficient memory usage on a TPU, configure your training with the largest batch size that fits within its memory limits (configure a rematerialization policy with offloading to achieve the best MFU). Each TPU core leverages internal 8×128 vector registers for highly optimized matrix multiplications. Therefore, for peak performance and to minimize padding, your batch size should ideally be a multiple of 128. If a multiple of 128 is not feasible, try a multiple of 8. For more detailed explanations, see this [performance guide](https://cloud.google.com/tpu/docs/performance-guide).
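
To make the alignment rules above concrete, here is a small illustrative helper (the names and structure are ours, not part of MaxText) that rounds draft dimensions up to the preferred hardware multiple:

```python
import math

# Preferred MXU tiling multiple per TPU generation, from the guidance above.
MXU_MULTIPLE = {"v4": 128, "v5e": 128, "v5p": 128, "trillium": 256, "ironwood": 256}

def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    return math.ceil(value / multiple) * multiple

def align_dims(emb_dim: int, mlp_dim: int, head_dim: int, tpu: str = "trillium") -> dict:
    m = MXU_MULTIPLE[tpu]
    return {
        "emb_dim": round_up(emb_dim, m),
        "mlp_dim": round_up(mlp_dim, m),
        "head_dim": round_up(head_dim, m),
    }

# A draft config with emb_dim=5000 would be padded up to 5120 on Trillium.
print(align_dims(emb_dim=5000, mlp_dim=17000, head_dim=192, tpu="trillium"))
```
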

+### Ironwood
+
+Ironwood is engineered for cutting-edge, large-scale AI model training and inference. To unlock its full potential, the primary goal is to continuously supply data to its powerful TensorCores, preventing bottlenecks from memory or the Inter-Chip Interconnect (ICI).
+
+We have published optimized recipes for models like DeepSeek v3, GPT-OSS, Qwen3, and Llama3 on Ironwood, covering both BF16 and FP8 precision, available in this [guide](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/training/ironwood).
+
+Key strategies to maximize performance on Ironwood include:
+* Adopt FP8 Precision: Ironwood delivers 2x throughput with FP8 compared to BF16. Design models to use mixed-precision training, employing FP8 for weights and activations where possible to maximize computational speed.
+* Offload to SparseCores: Ironwood's enhanced SparseCores are crucial for efficiency. Offload collective communication and data management to keep TensorCores focused on compute.
+* Leverage the dual-chiplet architecture: Each Ironwood chip contains two TensorCores with an ultra-fast die-to-die interconnect (6x faster than a 1D ICI link).
+
+Given Ironwood's high compute power, communication bandwidth can easily become the limiting factor. To address this:
+* Enable SparseCore offloading for collectives: By setting the appropriate [XLA flags](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/benchmarks/xla_flags_library.py#L70-L116), you can offload collective operations (like All-Reduce, All-Gather, etc.) to the SparseCores. These operations then run in parallel with the TensorCore computations, effectively hiding communication latency and improving Model FLOPs Utilization (MFU).
+* Optimize sharding strategies: Align your model distribution with the hardware topology. Choose sharding strategies (e.g., data, tensor, pipeline parallelism) that minimize data transfer over the ICI and maximize the overlap between computation and communication.
+
### Performance configs

Use these general runtime configurations to improve your model's performance.
@@ -58,18 +75,21 @@ Use these general runtime configurations to improve your model's performance.

* **Flash Attention**. Use the largest possible block size to maximize throughput.

-* **Memory usage**. To free up memory with large models, use custom remat policy to offload layer activations (including inputs, query, key, value, out projections, etc) to the host CPU.
+* **Memory usage**. To free up memory with large models, use a custom remat policy to offload layer activations (including inputs, attention, and MLP blocks) to the host CPU.

-* **Compiler flags**. XLA is the backend compiler for TPUs. Many critical performance settings can be controlled directly through XLA flags. We suggest beginning with the proven flags we have tested and provided [here](https://github.com/AI-Hypercomputer/maxtext/blob/02b6b8d2558f7dab7d2be024783977bdbb3ed251/benchmarks/xla_flags_library.py).
+* **Compiler flags**. XLA is the backend compiler for TPUs. Many critical performance settings can be controlled directly through XLA flags. We suggest beginning with the proven flags we have tested and provided [here](https://github.com/AI-Hypercomputer/maxtext/blob/b53bf3bef6b54b1d4939a4b700bc11fe149d1128/benchmarks/xla_flags_library.py).

* **Benchmark**. For consistent speed tests, set `reuse_example_batch=1` to repeatedly use the same data batch, isolating computation speed from data loading. Or use on-the-fly generated data by setting `dataset_type=synthetic`.

-(roofline-sharding)=
## Step 3. Choose efficient sharding strategies using Roofline Analysis

-To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for Trillium that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding_on_TPUs) and Jax’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/).
+To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding) and JAX’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/).

-For the calculation below on Trillium, we will use Arithmetic Intensity (AI) of 5100 for 2 ICI links bandwidth bandwidth (1D with wrapound or 2D without wraparound) and 2500 for 4 ICI links bandwidth (2D with wraparound on both dimensions) over the ICI. The later bandwidth is particularly for Trillium v6e-256 (16x16) with wraparound connection.
+| TPU Type | ICI Arithmetic Intensity |
+|---|---|
+| v5p | 2550 for 1D-ICI |
+| Trillium | 5100 for 1D-ICI (1D with wraparound or 2D without wraparound) <br> 2550 for 2D-ICI (2D with wraparound on both dimensions), particularly for v6e-256 |
+| Ironwood | 12800 for 1D-ICI |
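
The table above, combined with the per-strategy arithmetic intensity formulas in the sections below, reduces to a one-line roofline check. An illustrative sketch (the constants simply restate the table; the names are ours):

```python
# ICI arithmetic intensity (AI) rooflines from the table above.
HARDWARE_AI = {
    "v5p_1d": 2550,
    "trillium_1d": 5100,   # 1D with wraparound, or 2D without wraparound
    "trillium_2d": 2550,   # 2D with wraparound on both dimensions, e.g. v6e-256
    "ironwood_1d": 12800,
}

def is_compute_bound(sharding_ai: float, hardware: str) -> bool:
    """Roofline check: the sharding strategy's AI must exceed the hardware's
    ICI AI, otherwise the collectives cannot be hidden behind compute."""
    return sharding_ai > HARDWARE_AI[hardware]

# Example: pure FSDP has AI = global_batch / sparsity (see the FSDP section),
# so a sparsity-16 MoE with a 64k-token global batch has AI = 4000.
print(is_compute_bound(64_000 / 16, "trillium_2d"))  # True: 4000 > 2550
```
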

### Fully Sharded Data Parallelism (FSDP)

@@ -81,11 +101,22 @@ FSPD AI: `global batch / sparsity` (`sparsity = num_experts / num_experts_per_to

**Example with a sparsity of 16**:
* `global batch / sparsity > hardware AI`
-* `global batch / 16 > 2500` (16x16 with wraparound)
+
+v5p:
+* `global batch / 16 > 2550`
+* `global batch > 40k` (in tokens)
+
+Trillium:
+* `global batch / 16 > 2550` (16x16 with wraparound)
* `global batch > 40k` (in tokens)

We also need a single layer of weights to fit into memory, which can be an issue for medium/large MoE models: e.g. DeepSeek has roughly 10B params per layer, which corresponds to 40GiB of bf16 weights and gradients and will not fit into Trillium’s 32GiB of HBM. So pure FSDP on Trillium is feasible for models with layers not exceeding roughly 5B parameters; larger models need Expert or Tensor Parallelism.

+Ironwood:
+* `global batch / 16 > 12800`
+* `global batch > 205k` (in tokens)
+
+
#### Mix FSDP

For sparse models, large models, or when scaling to a large number of chips, FSDP can be used in conjunction with other sharding strategies, such as Expert Parallelism (EP), Tensor Parallelism (TP), and Pipeline Parallelism (PP).
@@ -94,10 +125,20 @@ The same AI as derived in the Pure FSDP section above still hold, we need `globa

**Example with EP=16, FSDP=16, and sparsity=32**:
* `pdb * EP / sparsity > hardware AI`
+
+v5p:
+* `pdb * 16 / 32 > 2550`
+* `pdb > 2550 * 32 / 16 = 5k` (in tokens)
+
+Trillium:
* `pdb * 16 / 32 > 5100`
-* `pdb > 5100 * 32 / 16 = 10200` (in tokens)
+* `pdb > 5100 * 32 / 16 = 10k` (in tokens)

-We need a per device batch of at least 10200 in this case.
+Ironwood:
+* `pdb * 16 / 32 > 12800`
+* `pdb > 12800 * 32 / 16 = 26k` (in tokens)
+
+We need a per-device batch of at least 5k tokens for v5p, 10k for Trillium, and 26k for Ironwood in this case.
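
The minimum batch sizes quoted above follow directly from rearranging the two inequalities; a short illustrative cross-check:

```python
# Pure FSDP:  AI = global_batch / sparsity  =>  global_batch > AI_hw * sparsity
# FSDP + EP:  AI = pdb * EP / sparsity      =>  pdb > AI_hw * sparsity / EP
def min_global_batch_pure_fsdp(hardware_ai: int, sparsity: int) -> int:
    return hardware_ai * sparsity

def min_pdb_fsdp_ep(hardware_ai: int, sparsity: int, ep: int) -> int:
    return hardware_ai * sparsity // ep

# Pure FSDP, sparsity 16: ~40k tokens at AI 2550, ~205k tokens on Ironwood.
print(min_global_batch_pure_fsdp(2550, 16), min_global_batch_pure_fsdp(12800, 16))
# -> 40800 204800

# FSDP=16, EP=16, sparsity 32: ~5k (v5p), ~10k (Trillium 1D), ~26k (Ironwood) per device.
print(min_pdb_fsdp_ep(2550, 32, 16), min_pdb_fsdp_ep(5100, 32, 16), min_pdb_fsdp_ep(12800, 32, 16))
# -> 5100 10200 25600
```
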

### Expert Parallelism (EP)

@@ -106,14 +147,20 @@ If pure FSDP doesn’t work either due to AI or to fit in layer weights, EP is g
AI of 1D EP on ICI rings `= 4 * mlp_dim / EP`. Communication cost of all-to-all is roughly 1/4 of all-gather and reduce-scatter.

**Example with EP=4**
+
+v5p:
+* `4 * M > 2550 * 4`
+* `M > 2.5k`
+
+Trillium:
* `4 * M > 5100 * 4`
-* `M > 5,100 * 4 = 5,100`
+* `M > 5k`

-**Example with EP=16**
-* `4 * M > 5,100 * 16`
-* `M > 5,100 * 4 = 20,400`
+Ironwood:
+* `4 * M > 12800 * 4`
+* `M > 13k`

-These examples show that to use EP, we need a large enough mlp dimension.
+These examples show that to use EP, we need a large enough MLP dimension.

It's important to note that this is only a roofline analysis. A no-cap strategy with a high degree of EP introduces additional overhead: load balancing across expert groups becomes more challenging.
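
The same rearrangement gives the minimum MLP dimension for 1D expert parallelism (illustrative sketch):

```python
# 1D EP on an ICI ring: AI = 4 * mlp_dim / EP, so mlp_dim > AI_hw * EP / 4.
def min_mlp_dim_for_ep(hardware_ai: int, ep: int) -> int:
    return hardware_ai * ep // 4

# With EP=4, the minimum M is roughly the hardware AI itself (2.5k / 5k / 13k above).
print(min_mlp_dim_for_ep(2550, 4), min_mlp_dim_for_ep(5100, 4), min_mlp_dim_for_ep(12800, 4))
# -> 2550 5100 12800
```
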
@@ -125,25 +172,70 @@ AI of TP: M / TP

**Example with TP=4**
* `M / TP > hardware AI`
+
+v5p:
+* `M / 4 > 2550`
+* `M > 10k`
+
+Trillium:
* `M / 4 > 5100`
-* `M > 20400`
+* `M > 20k`

-We have seen in practice M should be even larger- ideally 40k+. This is what we use for Llama-405B (M=53k), and was used for a custom sparse 10T model (M=40k, 64 experts).
+We have seen in practice that M should be even larger, ideally 40k+. This is what we use for Llama-405B (M=53k), and was used for a custom sparse 10T model (M=40k, 64 experts). TP=4 corresponds to a custom Trillium mesh, an 8x8 ring of 2x2 subrings (the TP communication operates on the 2x2 ring). This 2x2 ring performs well (near roofline), but the 8x8 rings perform poorly (0.5 x 1 axis). E.g. if we use FSDP=64, TP=4, the FSDP=64 communications will be slower than the hardware ICI roofline, so we prefer to use the full 16 axis when M is large enough.

-TP=4 corresponds to a custom Trillium mesh, an 8x8 ring of 2x2 subrings (the TP communication operates on the 2x2 ring). This 2x2 ring performs well (near roofline), but the 8x8 rings perform poorly (0.5 x 1 axis). E.g. if we use FSDP=64, TP=4, the FSDP=64 communications will be slower than the hardware ICI roofline, so we prefer to use the full 16 axis when M is large enough.
+Ironwood:
+* `M / 4 > 12800`
+* `M > 51k`

**Example with TP=16**
* `M / TP > hardware AI`
+
+v5p:
+* `M / 16 > 2550`
+* `M > 41k`
+
+Trillium:
* `M / 16 > 5100`
-* `M > 81600`
+* `M > 82k`

To use TP=16, we need M > 80k (ideally larger, 100k+). We have used this in a custom dense model (900B, M=131k), which performs very well even at 1k per-device tokens (scaling to 25k+ with a reasonable global batch).

+### Pipeline Parallelism (PP)
+
+Pipeline Parallelism is advantageous when the global batch size limits the per-device batch size, making Data Parallelism (DP) inefficient. PP is associated with small communication costs since it only needs to permute the small layer inputs.
+
+AI of PP: `3/2 * layers_per_pipeline_stage * M * num_experts_per_tok`
+
+**Example with PP=16, layers_per_pipeline_stage=1, num_experts_per_tok=8**
+* `3/2 * layers_per_pipeline_stage * M * num_experts_per_tok > hardware AI`
+
+v5p - PP over ICI:
+* `3 * M * 8 / 2 > 2550`
+* `M > 210`
+
+v5p - PP over DCN:
+* `3 * M * 8 / 2 > 73000`
+* `M > 6k`
+
+Trillium over ICI:
+* `3 * M * 8 / 2 > 5100`
+* `M > 420`
+
+Trillium over DCN:
+* `3 * M * 8 / 2 > 73000`
+* `M > 6k`
+
+Ironwood over ICI:
+* `3 * M * 8 / 2 > 12800`
+* `M > 1100`
+
+It is important to emphasize that this is a theoretical roofline analysis. Real-world performance will depend on the efficiency of the implementation and XLA compilation on the TPU. Refer to this [link](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md#pp--fsdpdp) for specific challenges regarding PP + FSDP/DP.
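
As with FSDP and EP, the TP and PP minimums above follow mechanically from their AI formulas; a brief illustrative check:

```python
# TP: AI = M / TP                                       =>  M > AI_hw * TP
# PP: AI = 3/2 * layers_per_stage * M * experts_per_tok
#     =>  M > 2 * AI_hw / (3 * layers_per_stage * experts_per_tok)
def min_mlp_dim_for_tp(hardware_ai: int, tp: int) -> int:
    return hardware_ai * tp

def min_mlp_dim_for_pp(hardware_ai: float, layers_per_stage: int, experts_per_tok: int) -> float:
    return 2 * hardware_ai / (3 * layers_per_stage * experts_per_tok)

print(min_mlp_dim_for_tp(5100, 4), min_mlp_dim_for_tp(5100, 16))
# -> 20400 81600   (Trillium, TP=4 and TP=16)
print(round(min_mlp_dim_for_pp(5100, 1, 8)), round(min_mlp_dim_for_pp(73000, 1, 8)))
# -> 425 6083      (Trillium PP over ICI, and PP over DCN)
```
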
+
## Step 4. Analyze experiments

With your configs, begin experimenting to evaluate the model's performance. We strongly recommend capturing a profile by following these [instructions](https://docs.jax.dev/en/latest/profiling.html#). If you are using MaxText, this can be done by simply setting `profiler=xplane` in your configuration.
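
If you are not launching through MaxText, the same kind of trace can be captured directly with JAX's profiler API. A minimal sketch (the log directory and the dummy `step` computation are placeholders for your own training loop):

```python
import jax
import jax.numpy as jnp

log_dir = "/tmp/profile"  # placeholder output directory for the trace

@jax.jit
def step(x):
    # Stand-in for a real train step; replace with your training function.
    return jnp.sin(x) @ x.T

x = jnp.ones((1024, 1024))

# Trace a handful of steps rather than the whole run; the captured trace can
# then be opened with XProf / TensorBoard as described below.
with jax.profiler.trace(log_dir):
    for _ in range(5):
        x = step(x)
    jax.block_until_ready(x)  # make sure the traced work finishes inside the trace
```
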

-After generating the profile, use a tool, like [xprof](https://github.com/openxla/xprof), [xprofiler](https://github.com/AI-Hypercomputer/cloud-diagnostics-xprof), or [tensorboard](https://github.com/tensorflow/tensorboard) to analyze the results. This example ([Profile TPU Programs](https://jax-ml.github.io/scaling-book/profiling/) can serve as your guide. A key principle for maximizing training throughput is to ensure you are fully utilizing the available HBM. Once you achieve satisfactory performance, you can proceed with full training runs. Continue to analyze your model and refine your configurations as needed.
+After generating the profile, use a tool like [xprof](https://github.com/openxla/xprof), [xprofiler](https://github.com/AI-Hypercomputer/cloud-diagnostics-xprof), or [tensorboard](https://github.com/tensorflow/tensorboard) to analyze the results. This example ([Profile TPU Programs](https://jax-ml.github.io/scaling-book/profiling/)) can serve as your guide. A key principle for maximizing training throughput is to ensure you are fully utilizing the available HBM. Once you achieve satisfactory performance, you can proceed with full training runs. Continue to analyze your model and refine your configurations as needed.

## Example of dense model

@@ -209,4 +301,4 @@ Objective was to demonstrate achieving reasonable MFU on a low batch setting (2k
| **Total Params** | 1.04E+13 |
| **Active Params** | 3.76E+11 |
| **MFU (1 pod Trillium)** | 34.5% |
-| **MFU(16 pods Trillium)** | 26.2% |
+| **MFU (16 pods Trillium)** | 26.2% |
