<!--
 Copyright 2025 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 -->

# Mixture of Experts (MoE) Configuration

This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/MaxText/configs/base.yml`, and the parameters are primarily consumed in `src/MaxText/layers/moe.py`.


## 1. Architecture

### MoE Strategy
MaxText supports both Dropless and Dropping strategies. Please refer to the decision tree below to determine the active strategy; example overrides for each path are shown after the lists below.


*Figure 1: Decision Logic for MaxText MoE Strategies.*

Dropless:
* [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`.
* [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`.
* [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`.
* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`.

Dropping:
* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor > 0` (commonly 1.0 to 1.25).

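For illustration, the overrides below select each of the paths above. This is a sketch with placeholder values, not tuned recommendations; see `src/MaxText/configs/base.yml` for the authoritative defaults.

```yaml
# Dropless with Tokamax Ragged Dot (illustrative values):
sparse_matmul: True
use_tokamax_gmm: True

# Dropless with Megablox instead:
#   sparse_matmul: True
#   use_tokamax_gmm: False
#   megablox: True

# Dropless with JAX Ragged Dot instead:
#   sparse_matmul: True
#   use_tokamax_gmm: False
#   megablox: False

# Dropless with dense matmul instead:
#   sparse_matmul: False
#   capacity_factor: -1

# Dropping with dense matmul instead:
#   sparse_matmul: False
#   capacity_factor: 1.25
```
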
### General Configuration
`num_experts`: The total number of routed experts available in the MoE layer.

`num_experts_per_tok`: The number of experts selected for each token, often referred to as the top-k value.

`shared_experts`: The number of experts that are always active for every token, in addition to the routed experts.

`base_moe_mlp_dim`: The intermediate dimension size for the MLP blocks within the experts.

`interleave_moe_layer_step`: Defines the frequency of MoE layers in the transformer stack. If set to 1, every layer is an MoE layer. If set to X, an MoE layer appears every X layers.

`first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.

`float32_weight_sum`: If enabled, performs the summation of expert weights in float32 precision for improved numerical stability.

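As a concrete illustration (hypothetical values, not a recommended model), an interleaved MoE architecture using these parameters might be configured as follows:

```yaml
# Hypothetical setup: 64 routed experts, top-2 routing, 1 shared expert,
# one initial dense layer, then an MoE layer every second layer.
num_experts: 64
num_experts_per_tok: 2
shared_experts: 1
base_moe_mlp_dim: 8192
first_num_dense_layers: 1
interleave_moe_layer_step: 2
float32_weight_sum: True
```
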
### Routing Mechanism
`use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.

`n_routing_groups` and `topk_routing_group`: Experts are divided into `n_routing_groups` groups. The router first selects the `topk_routing_group` highest-scoring groups, and then selects the top-k experts only from within those groups.

`routed_bias`: If enabled, adds a learnable bias term to the gate logits to facilitate load balancing.

`routed_score_func`: Defines the scoring function applied to the router logits.

`routed_scaling_factor`: A scalar multiplier applied to the expert weights.

`load_balance_loss_weight`: Sets the coefficient for the auxiliary loss term used to encourage balanced token distribution among experts.

`norm_topk_prob`: If enabled, normalizes the router weights of the selected top-k experts so that they sum to one.

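A sketch of a grouped-routing configuration in the style of DeepSeek-like models is shown below. All values are illustrative assumptions, including the `"sigmoid"` score function; check `src/MaxText/configs/base.yml` for the accepted options and defaults.

```yaml
# Illustrative grouped routing: 256 experts split into 8 groups,
# top-4 groups selected first, then top-8 experts within them.
num_experts: 256
num_experts_per_tok: 8
n_routing_groups: 8
topk_routing_group: 4
routed_bias: True
routed_score_func: "sigmoid"   # assumed value; see base.yml for valid choices
routed_scaling_factor: 2.5
norm_topk_prob: True
load_balance_loss_weight: 0.001
```
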
### MLP Block & Computation
`sparse_matmul`: Determines whether to use efficient sparse matrix multiplication or dense matrix multiplication.
 * `True`: Uses specialized kernels (like Tokamax Ragged Dot or Megablox) or JAX Ragged Dot to perform computation only on active tokens. This is generally faster for MoE.
 * `False`: Performs dense computation with masking. This is typically used when checking numerical correctness or implementing dropping strategies.

`use_tokamax_gmm`: If enabled, uses the Tokamax library's Ragged Dot kernel for sparse matmul. Recommended for dropless configurations.

`megablox`: If enabled, uses Megablox for sparse matrix operations. Effective only when `use_tokamax_gmm` is `False`.

`capacity_factor`: A scalar multiplier for expert capacity. Effective only when `sparse_matmul` is `False`.
 * Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
 * Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.

`use_custom_sort_vjp`: If enabled, uses a custom vector-Jacobian product (VJP) for the sort operation to make the backward pass of sparse matmul more efficient.

`mlp_bias`: If enabled, adds bias terms within the expert MLP layers.

`use_batch_split_schedule` (experimental): If enabled, splits the batch into micro-batches to hide communication behind computation.

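The computation-related flags above combine as in the sketch below, a dropless Megablox configuration with the custom sort VJP enabled; the values are illustrative, not tuned recommendations.

```yaml
# Illustrative dropless computation path using Megablox.
sparse_matmul: True
use_tokamax_gmm: False
megablox: True
capacity_factor: -1              # only takes effect when sparse_matmul is False
use_custom_sort_vjp: True
mlp_bias: False
use_batch_split_schedule: False  # experimental micro-batching to hide communication
```
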
## 2. Sharding
`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
 * `fsdp`: Treats the expert axis as an FSDP axis.
 * `context`: Treats the expert axis as a context parallelism axis, useful for long-context workloads.

`use_ring_of_experts` (experimental): If enabled, reduces all-to-all communication in EP sharding by circulating tokens in a ring topology rather than sending them all at once.

`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.

`fsdp_shard_on_exp`: If enabled, shards the MLP weights along the expert dimension instead of the embedding dimension during FSDP sharding.

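The sharding flags might be combined as in the following sketch; the choices shown are illustrative examples rather than recommendations.

```yaml
# Illustrative sharding-related overrides.
expert_shard_attention_option: "fsdp"    # or "context" for long-context workloads
use_ring_of_experts: False               # experimental
moe_fsdp_use_two_stage_all_gather: True  # preferred when 3D All-Gather is unavailable
fsdp_shard_on_exp: False                 # keep default sharding on the embedding dimension
```
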
## 3. Performance Tuning
These parameters provide granular control over the tiling dimensions of the sparse matmul Pallas kernels. A full parameter name is composed of a layer, a pass, and a dimension, as described below (see the example after the lists).

* `wi_tile_...`: Tile size for the first layer of the MLP (Input -> Hidden).
* `wo_tile_...`: Tile size for the second layer of the MLP (Hidden -> Output).

For each layer, you can control the pass:
* `..._fwd_...`: Tile size for the forward pass.
* `..._dlhs_...`: Tile size for the backward-pass gradient calculation w.r.t. activations.
* `..._drhs_...`: Tile size for the backward-pass gradient calculation w.r.t. weights.

For each pass, you can control the tile size along each dimension:
* `..._batch_seq`: Tile size for the batch x sequence dimension.
* `..._embed_dim`: Tile size for the embedding dimension.
* `..._mlp_dim`: Tile size for the MLP dimension.

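Putting the pieces together, the sketch below assumes the name components compose in layer/pass/dimension order (e.g., `wi_tile_fwd_batch_seq`); the tile sizes are placeholders that only illustrate the naming format, not tuned values.

```yaml
# Hypothetical tile sizes for the first MLP layer, forward pass.
wi_tile_fwd_batch_seq: 512
wi_tile_fwd_embed_dim: 1024
wi_tile_fwd_mlp_dim: 1024

# Hypothetical tile sizes for the second MLP layer, backward pass w.r.t. weights.
wo_tile_drhs_batch_seq: 256
wo_tile_drhs_embed_dim: 512
wo_tile_drhs_mlp_dim: 512
```
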
Implementation Support:
* Megablox/JAX Ragged Dot:
  * Supports the forward pass only (6 configs: `wi_tile_fwd_...` and `wo_tile_fwd_...`).
  * Configs are enabled for INT8, BF8, and BF16.

* Tokamax Ragged Dot:
  * Supports all 18 configurations. **Note**: Currently enabled for FP8 quantization; BF16 integration is in progress.