
Commit cf75890

Merge pull request #2758 from AI-Hypercomputer:moe_strategy_doc
PiperOrigin-RevId: 840573889
2 parents 372b7c1 + 9011dd9 commit cf75890

File tree

4 files changed: +126 -1 lines changed


docs/_static/moe_strategy.png

330 KB

docs/reference/core_concepts.md

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@ core_concepts/alternatives.md
 core_concepts/quantization.md
 core_concepts/tiling.md
 core_concepts/jax_xla_and_pallas.md
+core_concepts/moe_configuration.md
 ```
docs/reference/core_concepts/moe_configuration.md

Lines changed: 122 additions & 0 deletions

<!--
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Mixture of Experts (MoE) Configuration

This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/MaxText/configs/base.yml` and are primarily used in `src/MaxText/layers/moe.py`.

## 1. Architecture

### MoE Strategy

MaxText supports both Dropless and Dropping strategies. Refer to the decision tree below to determine the active strategy; an illustrative configuration sketch follows the lists below.

![Illustration of MoE strategy](../../_static/moe_strategy.png)

*Figure 1: Decision Logic for MaxText MoE Strategies.*
Dropless:
* [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`.
* [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`.
* [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`.
* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`.

Dropping:
* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor > 0` (commonly 1.0 to 1.25).
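
As an illustration only (not defaults or recommendations), the sketch below shows how these flags might be combined; the parameter names come from `base.yml`, while the values and commented alternatives are examples:

```yaml
# Dropless with Tokamax Ragged Dot (sparse matmul path).
sparse_matmul: True
use_tokamax_gmm: True

# Dropless with Megablox instead of Tokamax:
# sparse_matmul: True
# use_tokamax_gmm: False
# megablox: True

# Dropping with dense matmul and a strict capacity limit:
# sparse_matmul: False
# capacity_factor: 1.25
```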

### General Configuration

`num_experts`: The total number of routed experts available in the MoE layer.

`num_experts_per_tok`: The number of experts selected for each token, often referred to as the top-k strategy.

`shared_experts`: The number of experts that are always active for every token, in addition to the routed experts.

`base_moe_mlp_dim`: The intermediate dimension size for the MLP blocks within the experts.

`interleave_moe_layer_step`: Defines the frequency of MoE layers in the transformer. If set to 1, every layer is an MoE layer; if set to X, an MoE layer appears every X layers.

`first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.

`float32_weight_sum`: If enabled, performs the summation of expert weights in float32 precision for improved numerical stability.
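
Putting these together, a sketch of an MoE architecture configuration; every value below is purely illustrative, not a MaxText default:

```yaml
num_experts: 64               # 64 routed experts per MoE layer
num_experts_per_tok: 4        # each token is routed to its top-4 experts
shared_experts: 1             # one always-active shared expert per token
base_moe_mlp_dim: 2048        # intermediate width of each expert MLP
interleave_moe_layer_step: 1  # every layer is an MoE layer
first_num_dense_layers: 1     # the first layer stays dense
float32_weight_sum: True      # sum expert weights in float32 for stability
```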

### Routing Mechanism

`use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.

`n_routing_groups` and `topk_routing_group`: Experts are divided into `n_routing_groups` groups. The router first selects the `topk_routing_group` highest-scoring groups, and then selects experts only from within those groups.

`routed_bias`: If enabled, adds a learnable bias term to the gate logits to facilitate load balancing.

`routed_score_func`: Defines the scoring function for the router.

`routed_scaling_factor`: A scalar multiplier applied to the expert weights.

`load_balance_loss_weight`: Sets the coefficient for the auxiliary loss term used to encourage balanced token distribution among experts.

`norm_topk_prob`: If enabled, normalizes the router weights for the selected top-k experts.
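
To make group-limited routing concrete, a hedged sketch with illustrative values (only `use_random_routing: False` and `load_balance_loss_weight: 0.01` match the defaults shown in `base.yml` below; the rest are examples):

```yaml
use_random_routing: False      # use real gate logits, not random routing
n_routing_groups: 8            # split the routed experts into 8 groups
topk_routing_group: 4          # keep only the 4 highest-scoring groups,
                               # then pick num_experts_per_tok experts from them
routed_bias: True              # learnable bias on gate logits for load balancing
routed_scaling_factor: 1.0     # no extra scaling of expert weights
norm_topk_prob: True           # renormalize the selected top-k weights
load_balance_loss_weight: 0.01 # auxiliary load-balance loss coefficient
```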
### MLP Block & Computation
71+
`sparse_matmul`: Determines whether to use efficient sparse matrix multiplication or dense matrix multiplication.
72+
* `True`: Uses specialized kernels (like Tokamax Ragged Dot or Megablox) or JAX Ragged Dot to perform computation only on active tokens. This is generally faster for MoE.
73+
* `False`: Performs dense computation with masking. This is typically used when checking numerical correctness or implementing dropping strategies.
74+
75+
`use_tokamax_gmm`: If enabled, use Tokamax library's Ragged Dot for matmul. Recommended for dropless configurations.
76+
77+
`megablox`: If enabled, use Megablox for sparse matrix operations. Effective only when `use_tokamax_gmm` is False.
78+
79+
`capacity_factor`: A scalar multiplier for expert capacity. Effective only when `sparse_matmul` is False.
80+
* Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
81+
* Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
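
As a rough estimate only (this follows the common MoE convention and is not necessarily MaxText's exact formula), the per-expert capacity is approximately `capacity_factor * (tokens_in_batch * num_experts_per_tok) / num_experts`. For example, with `capacity_factor=1.25`, `num_experts=64`, `num_experts_per_tok=4`, and 8192 tokens, each expert can hold roughly 1.25 * 8192 * 4 / 64 = 640 tokens; tokens routed beyond that limit are dropped.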

`use_custom_sort_vjp`: If enabled, uses a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.

`mlp_bias`: If enabled, adds bias terms within the expert MLP layers.

`use_batch_split_schedule` (experimental): If enabled, splits the batch into micro-batches to hide communication behind computation.
## 2. Sharding

`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
* `fsdp`: Treats the expert axis as an FSDP axis.
* `context`: Treats the expert axis as a context parallelism axis, useful for long contexts.

`use_ring_of_experts` (experimental): If enabled, reduces all-to-all communication in expert-parallel (EP) sharding by circulating tokens in a ring topology rather than sending them all at once.

`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.

`fsdp_shard_on_exp`: If enabled, shards MLP weights along the expert dimension instead of the embedding dimension during FSDP sharding.
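
A hedged sharding sketch; `use_ring_of_experts: False` matches the default shown in `base.yml` below, while the other values are illustrative examples rather than recommendations:

```yaml
expert_shard_attention_option: "fsdp"    # reuse the expert axis as an FSDP axis in attention
use_ring_of_experts: False               # standard all-to-all EP communication
moe_fsdp_use_two_stage_all_gather: True  # two-stage All-Gather when 3D All-Gather is unavailable
fsdp_shard_on_exp: False                 # shard MLP weights on the embedding (not expert) dimension
```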
## 3. Performance Tuning

These parameters provide granular control over the tiling dimensions for the sparse matmul Pallas kernels. The full parameter names combine the pieces below (an example follows the lists).

* `wi_tile_...`: Tile size for the first layer of the MLP (Input -> Hidden).
* `wo_tile_...`: Tile size for the second layer of the MLP (Hidden -> Output).

For each, you can control:
* `..._fwd_...`: Tile size for the forward pass.
* `..._dlhs_...`: Tile size for the backward pass gradient calculation w.r.t. activations.
* `..._drhs_...`: Tile size for the backward pass gradient calculation w.r.t. weights.

And for each of these, you can set the tile size along each dimension:
* `..._batch_seq`: Tile size for the batch x sequence dimension.
* `..._embed_dim`: Tile size for the embedding dimension.
* `..._mlp_dim`: Tile size for the MLP dimension.
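
For example, the forward-pass tiles for the first MLP layer shown in `base.yml` below are:

```yaml
wi_tile_fwd_batch_seq: 512   # forward pass, first MLP layer, batch x sequence tile
wi_tile_fwd_embed_dim: 1024  # forward pass, first MLP layer, embedding tile
wi_tile_fwd_mlp_dim: 1024    # forward pass, first MLP layer, MLP-dimension tile
```

The backward-pass names presumably follow the same pattern (e.g., `wi_tile_dlhs_batch_seq`); check `base.yml` for the exact spellings.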
Implementation Support:
* Megablox/JAX Ragged Dot:
  * Supports the forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`).
  * Configs are enabled for INT8, BF8, and BF16.
* Tokamax Ragged Dot:
  * Supports all 18 configurations. **Note**: Currently enabled for FP8 quantization; BF16 integration is in progress.

src/MaxText/configs/base.yml

Lines changed: 3 additions & 1 deletion
@@ -179,7 +179,9 @@ load_balance_loss_weight: 0.01 # weight for the load balance loss
 use_random_routing: False # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
 use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
-# Tunable tiling dimensions used for MLP GMM, includes Tokamax ragged_dot and Megablox
+# Tunable tiling dimensions used for MLP GMM
+# Megablox/JAX Ragged Dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
+# Tokamax Ragged Dot - supports all 18 configs
 wi_tile_fwd_batch_seq: 512
 wi_tile_fwd_embed_dim: 1024
 wi_tile_fwd_mlp_dim: 1024
