diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 1f6473a08..249a65b9d 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -167,7 +167,6 @@ def __call__( else: attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - # MLP block. mlp_lnx = linears.mlp_block( in_features=lnx.shape[-1], intermediate_dim=cfg.mlp_dim, @@ -253,13 +252,140 @@ def __call__( page_state=page_state, ) if self.config.scan_layers: - inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). + inputs = inputs[0] if self.config.scan_layers: return inputs, None # pytype: disable=bad-return-type else: return inputs +class SequentialNNXWrapper(nnx.Module): + """Wrapper that creates sequential decoder layers for pipeline stages. + + This wrapper matches the decoder layer signature expected by Pipeline. + """ + + def __init__( + self, + decoder_layer_class: type, + num_decoder_layers: int, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | Quant = None, + ): + """Initialize wrapper with sequential decoder layers. + + Args: + decoder_layer_class: NNX decoder layer class to instantiate + num_decoder_layers: Number of layers to create + config: Model configuration + mesh: Device mesh + model_mode: 'train', 'eval', etc. + rngs: RNG state + quant: Quantization config + """ + self.sequential = SequentialBlockNNXDecoderLayers( + decoder_layer_class=decoder_layer_class, + num_decoder_layers=num_decoder_layers, + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant + ) + + def __call__(self, *args, **kwargs): + """Forward pass through sequential layers.""" + return self.sequential(*args, **kwargs) + + +class SequentialBlockNNXDecoderLayers(nnx.Module): + """Sequential unscanned series of NNX decoder layers.""" + + def __init__( + self, + decoder_layer_class: type, + num_decoder_layers: int, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | Quant = None, + ): + """Initialize multiple NNX decoder layer instances. + + Args: + decoder_layer_class: The NNX decoder layer class to instantiate + num_decoder_layers: Number of decoder layers to create + config: Model configuration + mesh: Device mesh for sharding + model_mode: 'train', 'eval', etc. + rngs: RNG state for initialization + quant: Quantization configuration + """ + self.config = config + self.num_decoder_layers = num_decoder_layers + + # Store layer instances as attributes for NNX pytree tracking. + for lyr in range(num_decoder_layers): + layer = decoder_layer_class( + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant, + ) + setattr(self, f'layer_{lyr}', layer) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids, + decoder_positions, + deterministic: bool, + model_mode, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + ) -> jnp.ndarray: + """Sequentially apply all decoder layers. + + Args: + inputs: Input tensor + decoder_segment_ids: Segment IDs for attention masking + decoder_positions: Position indices + deterministic: Whether to use deterministic mode (no dropout) + model_mode: 'train', 'eval', etc. + slot: Optional slot index for paged attention + page_state: Optional page state for paged attention + + Returns: + Output tensor after all layers, or (output, None) if scan_layers is True + """ + # Iterate over layer attributes (layer_0, layer_1, ...) 
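+    # Each layer returns (output, None) when scan_layers is True and a plain
+    # tensor otherwise, so the result is unwrapped before feeding the next layer.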
+ for lyr in range(self.num_decoder_layers): + layer = getattr(self, f'layer_{lyr}') + outputs = layer( + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=slot, + page_state=page_state, + ) + if self.config.scan_layers: + inputs = outputs[0] + else: + inputs = outputs + + if self.config.scan_layers: + return inputs, None + else: + return inputs + + class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" @@ -273,10 +399,14 @@ def setup(self): self.decoder_layer = self.get_decoder_layers() self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) if self.config.using_pipeline_parallelism: - pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) + nnx_decoder_classes = self.get_nnx_decoder_layers() + if nnx_decoder_classes is not None: + pipeline_stage_module = self.get_pipeline_stage_module(nnx_decoder_classes, use_nnx=True) + else: + pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer, use_nnx=False) remat_policy = self.get_remat_policy() - self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy + self.pipeline_module = pipeline.create_pipeline( + config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy, use_nnx=(nnx_decoder_classes is not None) ) def minimal_policy(self, with_context=False): @@ -302,13 +432,11 @@ def get_remat_policy(self): cfg = self.config if cfg.remat_policy != "none": if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): - # save all if cfg.remat_policy == "minimal_flash": max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") policy = self.minimal_policy(with_context=True) elif cfg.remat_policy == "minimal": - # save all except context policy = self.minimal_policy() elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( @@ -351,7 +479,6 @@ def get_remat_policy(self): offload_dst="pinned_host", ) elif cfg.remat_policy == "minimal_offloaded": - # offload all except context policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], names_which_can_be_offloaded=[ @@ -431,30 +558,78 @@ def get_decoder_layers(self): # Default case to handle any unknown decoder block types. raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + def get_nnx_decoder_layers(self): + """Retrieves pure NNX decoder layer classes (without Linen wrappers) for pipeline parallelism. + + Returns: + A list containing one or more NNX Module classes for the decoder. 
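+      Returns None when the configured decoder block has no pure NNX
+      implementation, in which case the caller falls back to the Linen layers.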
+ """ + match self.config.decoder_block: + case DecoderBlockType.DEFAULT: + return None + case DecoderBlockType.LLAMA2: + return [llama2.LlamaDecoderLayer] + case DecoderBlockType.MISTRAL: + return [mistral.MistralDecoderLayer] if hasattr(mistral, 'MistralDecoderLayer') else None + case DecoderBlockType.MIXTRAL: + return [mixtral.MixtralDecoderLayer] if hasattr(mixtral, 'MixtralDecoderLayer') else None + case DecoderBlockType.DEEPSEEK: + if self.config.use_batch_split_schedule: + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + else: + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] + case DecoderBlockType.GEMMA: + return [gemma.GemmaDecoderLayer] if hasattr(gemma, 'GemmaDecoderLayer') else None + case DecoderBlockType.GEMMA2: + return [gemma2.Gemma2DecoderLayer] if hasattr(gemma2, 'Gemma2DecoderLayer') else None + case DecoderBlockType.GEMMA3: + return [gemma3.Gemma3DecoderLayer] if hasattr(gemma3, 'Gemma3DecoderLayer') else None + case DecoderBlockType.GPT3: + return [gpt3.Gpt3DecoderLayer] + case DecoderBlockType.GPT_OSS: + if self.config.scan_layers: + return [gpt_oss.GptOssScannableBlock] if hasattr(gpt_oss, 'GptOssScannableBlock') else None + else: + return [gpt_oss.GptOssDecoderLayer] if hasattr(gpt_oss, 'GptOssDecoderLayer') else None + case DecoderBlockType.QWEN3: + return [qwen3.Qwen3DecoderLayer] if hasattr(qwen3, 'Qwen3DecoderLayer') else None + case DecoderBlockType.QWEN3_MOE: + return [qwen3.Qwen3MoeDecoderLayer] if hasattr(qwen3, 'Qwen3MoeDecoderLayer') else None + case DecoderBlockType.QWEN3_NEXT: + if self.config.scan_layers: + return [qwen3.Qwen3NextScannableBlock] if hasattr(qwen3, 'Qwen3NextScannableBlock') else None + else: + return [qwen3.Qwen3NextDecoderLayer] if hasattr(qwen3, 'Qwen3NextDecoderLayer') else None + case DecoderBlockType.SIMPLE: + return [simple_layer.SimpleDecoderLayer] + case DecoderBlockType.SIMPLE_MLP: + return [simple_layer.SimpleMlpDecoderLayer] + case DecoderBlockType.LLAMA4: + if self.config.scan_layers: + return [llama4.Llama4ScannableBlock] if hasattr(llama4, 'Llama4ScannableBlock') else None + else: + return [llama4.Llama4DecoderLayer] if hasattr(llama4, 'Llama4DecoderLayer') else None + case _: + return None + def set_remat_policy(self, block_layers, policy): """Set remat policy""" RemattedBlockLayers = [] for block_layer in block_layers: if self.config.parameter_memory_host_offload: - # Define parameter movement with mesh-based sharding def move_to_device(variables): """Move parameters to device with proper sharding.""" - def map_fn(path, value): - max_logging.log(f"models.py: Moving parameter {path} to device") return jax.device_put(value, max_utils.device_space()) - return jax.tree_util.tree_map_with_path(map_fn, variables) - # Transform layer class before remat block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) - # Apply remat policy to layer layer = nn.remat( block_layer, prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), policy=policy, - static_argnums=(4, 5), # Deterministic and model mode are static arguments. 
+ static_argnums=(4, 5), ) RemattedBlockLayers.append(layer) return RemattedBlockLayers @@ -510,17 +685,38 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args ) - def get_pipeline_stage_module(self, decoder_blocks): - """get pipeline stage module""" + def get_pipeline_stage_module(self, decoder_blocks, use_nnx=False): + """get pipeline stage module + + Args: + decoder_blocks: List of decoder layer classes (either Linen or NNX) + use_nnx: If True, decoder_blocks are NNX classes and should be passed to Pipeline + without instantiation. Pipeline will handle NNX instantiation with proper rngs. + """ def get_layer_to_pipeline(blocks, cfg): if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - return blocks[1] # return the sparse block + return blocks[1] else: return blocks[0] cfg = self.config base_stage = get_layer_to_pipeline(decoder_blocks, cfg) + + if use_nnx: + if cfg.num_layers_per_pipeline_stage == 1: + return base_stage + else: + return lambda config, mesh, model_mode, rngs, quant=None: SequentialNNXWrapper( + decoder_layer_class=base_stage, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + config=config, + mesh=mesh, + model_mode=model_mode, + rngs=rngs, + quant=quant + ) + if cfg.set_remat_policy_on_layers_per_stage: policy = self.get_remat_policy() base_stage = self.set_remat_policy([base_stage], policy)[0] @@ -563,7 +759,6 @@ def _apply_embedding( y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - # Merge the image embeddings with the text embeddings for multimodal models if image_embeddings is not None and cfg.use_multimodal: if cfg.model_name in [ "gemma3-4b", @@ -644,9 +839,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi ), ) - # [batch, length, emb_dim] -> [batch, length, vocab_size] if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. if isinstance(shared_embedding, nnx.Module): embedding_table = shared_embedding.embedding.value else: @@ -657,7 +850,6 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) if self.config.normalize_embedding_logits: - # Correctly normalize pre-softmax logits for this shared case. logits = logits / jnp.sqrt(y.shape[-1]) if cfg.final_logits_soft_cap: logits = logits / cfg.final_logits_soft_cap @@ -719,7 +911,6 @@ def __call__( policy = self.get_remat_policy() RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) - # scan does not support kwargs in layer call, passing broadcast_args as positional arg broadcast_args = ( decoder_segment_ids, decoder_positions, @@ -732,7 +923,7 @@ def __call__( y, decoder_segment_ids, decoder_positions, deterministic, model_mode ) else: - partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. + partition_spec = None if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
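+        # Only the sparse MoE layers are pipelined here; the dense layers (and any
+        # MoE layers kept outside the pipeline) are scanned with PP acting as DP.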
dense_layer = RemattedBlockLayers[0] @@ -740,7 +931,6 @@ def __call__( num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - # We chose not to pipeline the dense layers, only sparse for SPMD. with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): y, _ = self.scan_decoder_layers( cfg, @@ -762,7 +952,7 @@ def __call__( model_mode=model_mode, )(y, *broadcast_args) y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - else: # Not DeepSeek + else: y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers if remaining_layers > 0: @@ -850,7 +1040,6 @@ def __call__( layer_prefixes = ["dense_layers", "moe_layers"] num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] - # Iterate over the two layer groups (dense and MoE) and apply layer transformation for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): for index in range(num_layers): kv_cache = kv_caches[index] if kv_caches is not None else None @@ -876,7 +1065,6 @@ def __call__( layer_kwargs = {} layer_call_kwargs = {} if cfg.decoder_block == DecoderBlockType.GEMMA3: - # Gemma3 uses both global and sliding window attention depending on the layer index. layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} if cfg.decoder_block == DecoderBlockType.LLAMA4: @@ -910,19 +1098,14 @@ def __call__( assert isinstance(y, jax.Array) - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. hidden_state = y - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory - # Instead, we keep track on the hidden states, which has smaller size compared to full logits if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow("intermediates", "hidden_states", hidden_state) else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) - # The API of the Decoder is now a tuple, providing both the main output - # and the raw hidden state needed for auxiliary tasks. 
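+    # Return the logits together with the raw hidden state so callers can use
+    # the hidden state for auxiliary tasks.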
return logits, hidden_state, kv_caches def _apply_gemma3_scanned_blocks( @@ -942,7 +1125,6 @@ def _apply_gemma3_scanned_blocks( cfg = self.config mesh = self.mesh - # Define the repeating pattern length and calculate how many full blocks to scan attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) scan_length = cfg.num_decoder_layers // attention_pattern_length @@ -952,7 +1134,6 @@ def _apply_gemma3_scanned_blocks( layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} layer_kwargs = {"num_of_layers": attention_pattern_length} - # Apply the main scan over the full blocks if scan_length > 0: broadcast_args = ( decoder_segment_ids, @@ -971,10 +1152,8 @@ def _apply_gemma3_scanned_blocks( **layer_kwargs, )(y, *broadcast_args, **layer_call_kwargs) - # Apply any remaining layers that did not fit into a full scanned block num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length if num_remaining_layers > 0: - # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions rem_layer_kwargs = {"num_of_layers": num_remaining_layers} layer = RemattedGemma3Block( config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs diff --git a/src/MaxText/layers/decoders_linen.py b/src/MaxText/layers/decoders_linen.py new file mode 100644 index 000000000..aba50a8d6 --- /dev/null +++ b/src/MaxText/layers/decoders_linen.py @@ -0,0 +1,993 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""""Module for decoder layers.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +from typing import Any +import functools + +import jax +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh, NamedSharding + +from flax import linen as nn +from flax import nnx +from flax.linen.partitioning import ScanIn + +from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT +from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from MaxText import max_logging +from MaxText import max_utils +from MaxText.inference import page_manager +from MaxText.layers import linears +from MaxText.layers import quantizations +from MaxText.layers import pipeline +from MaxText import maxtext_utils +from MaxText import multimodal_utils +from MaxText import sharding +from MaxText.layers.attentions import attention_as_linen +from MaxText.layers.normalizations import rms_norm +from MaxText.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen +from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.layers import ( + deepseek, + deepseek_batchsplit, + gemma, + gemma2, + gemma3, + gpt3, + gpt_oss, + llama2, + llama4, + mistral, + mixtral, + qwen3, + simple_layer, +) + +# ------------------------------------------------------------------------------ +# The network: Decoder Definitions +# ------------------------------------------------------------------------------ + + +class DecoderLayer(nn.Module): + """ + Transformer decoder layer that attends to the encoder. + This is the core, reusable building block for both the main model's + decoder stack and the auxiliary MTP layers. + """ + + config: Config + mesh: Mesh + model_mode: str + quant: None | Quant = None + + @nn.compact + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") + else: + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + + if model_mode == MODEL_MODE_PREFILL: + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + else: + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + + inputs = checkpoint_name(inputs, "decoder_layer_input") + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + lnx = rms_norm( + num_features=inputs.shape[-1], + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + )(inputs) + if model_mode == MODEL_MODE_PREFILL: + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + else: + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_layer = attention_as_linen( + 
config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + model_mode=model_mode, + ) + + attention_lnx, kv_cache = attention_layer( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + if model_mode == MODEL_MODE_PREFILL: + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + else: + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + + # MLP block. + mlp_lnx = linears.mlp_block( + in_features=lnx.shape[-1], + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + model_mode=model_mode, + config=cfg, + quant=self.quant, + mesh=self.mesh, + )(lnx, deterministic=deterministic) + if model_mode == MODEL_MODE_PREFILL: + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + else: + mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + + next_layer_addition = mlp_lnx + attention_lnx + + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) + + layer_output = next_layer_addition_dropped_out + inputs + if model_mode == MODEL_MODE_PREFILL: + layer_output = _maybe_shard_with_logical( + layer_output, + logical_axis_names, + ) + else: + layer_output = _maybe_shard_with_logical( + layer_output, + logical_axis_names, + ) + + if cfg.record_internal_nn_metrics: + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) + self.sow( + "intermediates", + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output, kv_cache + + +class SequentialBlockDecoderLayers(nn.Module): + """Sequential unscanned series of decoder layers.""" + + decoder_layer: Any + num_decoder_layers: int + config: Config + mesh: Mesh + quant: Quant + model_mode: str + + @nn.compact + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids, + decoder_positions, + deterministic: bool, + model_mode, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + ) -> jnp.ndarray: + for lyr in range(self.num_decoder_layers): + inputs = self.decoder_layer( + config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode + )( + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=slot, + page_state=page_state, + ) + if self.config.scan_layers: 
+ inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). + if self.config.scan_layers: + return inputs, None # pytype: disable=bad-return-type + else: + return inputs + + +class Decoder(nn.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + + config: Config + mesh: Mesh + quant: None | Quant = None + model_mode: str = MODEL_MODE_TRAIN + + def setup(self): + """Initialize decoder layer.""" + self.decoder_layer = self.get_decoder_layers() + self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) + if self.config.using_pipeline_parallelism: + pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) + remat_policy = self.get_remat_policy() + self.pipeline_module = pipeline.Pipeline( + config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy + ) + + def minimal_policy(self, with_context=False): + """Helper for creating minimal checkpoint policies.""" + names = [ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ] + if with_context: + names.append("context") + return jax.checkpoint_policies.save_only_these_names(*names) + + def get_remat_policy(self): + """Get remat policy""" + policy = None + cfg = self.config + if cfg.remat_policy != "none": + if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): + # save all + if cfg.remat_policy == "minimal_flash": + max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") + max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") + policy = self.minimal_policy(with_context=True) + elif cfg.remat_policy == "minimal": + # save all except context + policy = self.minimal_policy() + elif cfg.remat_policy == "save_dot_with_context_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "context", + "out_proj", + ) + elif cfg.remat_policy == "save_dot_except_mlpwi": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", + ) + elif cfg.remat_policy == "save_dot_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + ) + elif cfg.remat_policy == "save_qkv_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + ) + elif cfg.remat_policy == "qkv_proj_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + # offload all except context + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwi_0", + "mlpwi_1", + "mlpwi", + "mlpwo", + ], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "custom": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=cfg.tensors_on_device, + names_which_can_be_offloaded=cfg.tensors_to_offload, + offload_src="device", + 
offload_dst="pinned_host", + ) + elif cfg.remat_policy == "save_out_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "out_proj", + ) + else: + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" + policy = None + return policy + + def get_decoder_layers(self): + """Retrieves a list of decoder layer classes based on the `decoder_block` config. + + Returns: + A list containing one or more `nn.Module` classes for the decoder. + """ + match self.config.decoder_block: + case DecoderBlockType.DEFAULT: + return [DecoderLayer] + case DecoderBlockType.LLAMA2: + return [llama2.LlamaDecoderLayerToLinen] + case DecoderBlockType.MISTRAL: + # TODO(ranran): update to Mistral with sliding window attention + return [mistral.MistralDecoderLayerToLinen] + case DecoderBlockType.MIXTRAL: + return [mixtral.MixtralDecoderLayerToLinen] + case DecoderBlockType.DEEPSEEK: + if self.config.use_batch_split_schedule: + return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + else: + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] + case DecoderBlockType.GEMMA: + return [gemma.GemmaDecoderLayerToLinen] + case DecoderBlockType.GEMMA2: + return [gemma2.Gemma2DecoderLayerToLinen] + case DecoderBlockType.GEMMA3: + return [gemma3.Gemma3DecoderLayerToLinen] + case DecoderBlockType.GPT3: + return [gpt3.Gpt3DecoderLayer] + case DecoderBlockType.GPT_OSS: + return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] + case DecoderBlockType.QWEN3: + return [qwen3.Qwen3DecoderLayerToLinen] + case DecoderBlockType.QWEN3_MOE: + return [qwen3.Qwen3MoeDecoderLayerToLinen] + case DecoderBlockType.QWEN3_NEXT: + return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen] + case DecoderBlockType.SIMPLE: + return [simple_layer.SimpleDecoderLayerToLinen] + case DecoderBlockType.SIMPLE_MLP: + return [simple_layer.SimpleMlpDecoderLayerToLinen] + case DecoderBlockType.LLAMA4: + return [llama4.Llama4ScannableBlockToLinen] if self.config.scan_layers else [llama4.Llama4DecoderLayerToLinen] + case _: + # Default case to handle any unknown decoder block types. + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def set_remat_policy(self, block_layers, policy): + """Set remat policy""" + RemattedBlockLayers = [] + for block_layer in block_layers: + if self.config.parameter_memory_host_offload: + # Define parameter movement with mesh-based sharding + def move_to_device(variables): + """Move parameters to device with proper sharding.""" + + def map_fn(path, value): + max_logging.log(f"models.py: Moving parameter {path} to device") + return jax.device_put(value, max_utils.device_space()) + + return jax.tree_util.tree_map_with_path(map_fn, variables) + + # Transform layer class before remat + block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) + + # Apply remat policy to layer + layer = nn.remat( + block_layer, + prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), + policy=policy, + static_argnums=(4, 5), # Deterministic and model mode are static arguments. 
+ ) + RemattedBlockLayers.append(layer) + return RemattedBlockLayers + + def get_norm_layer(self, num_features: int): + """get normalization layer (return type inherits from nn.Module)""" + if self.config.decoder_block in ( + DecoderBlockType.DEFAULT, + DecoderBlockType.LLAMA2, + DecoderBlockType.MISTRAL, + DecoderBlockType.MIXTRAL, + DecoderBlockType.DEEPSEEK, + DecoderBlockType.GEMMA, + DecoderBlockType.GEMMA2, + DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN3, + DecoderBlockType.QWEN3_MOE, + DecoderBlockType.QWEN3_NEXT, + DecoderBlockType.GPT_OSS, + DecoderBlockType.SIMPLE, + DecoderBlockType.SIMPLE_MLP, + DecoderBlockType.LLAMA4, + ): + return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) + elif self.config.decoder_block == DecoderBlockType.GPT3: + return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): + """scan decoder layers, calls `flax.linen.transforms.scan`""" + initializing = self.is_mutable_collection("params") + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + cache_spec = 0 + scan_fn = nn.scan( + decoder_layer, + variable_axes={ + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={ + "params": True, + "dropout": cfg.enable_dropout, + }, + in_axes=in_axes_tuple, + length=length, + metadata_params={nn.PARTITION_NAME: metadata_axis_name}, + ) + return scan_fn( + config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args + ) + + def get_pipeline_stage_module(self, decoder_blocks): + """get pipeline stage module""" + + def get_layer_to_pipeline(blocks, cfg): + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + return blocks[1] # return the sparse block + else: + return blocks[0] + + cfg = self.config + base_stage = get_layer_to_pipeline(decoder_blocks, cfg) + if cfg.set_remat_policy_on_layers_per_stage: + policy = self.get_remat_policy() + base_stage = self.set_remat_policy([base_stage], policy)[0] + if cfg.num_layers_per_pipeline_stage == 1: + stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) + elif cfg.scan_layers_per_stage: + stage_module = self.scan_decoder_layers( + cfg, + base_stage, + cfg.num_layers_per_pipeline_stage, + "layers_per_stage", + self.mesh, + in_axes_tuple=(nn.broadcast,) * 4, + ) + else: + stage_module = SequentialBlockDecoderLayers( + decoder_layer=base_stage, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + config=cfg, + mesh=self.mesh, + quant=self.quant, + model_mode=self.model_mode, + ) + return stage_module + + @nn.compact + def _apply_embedding( + self, + shared_embedding: nn.Module | nnx.Module, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings=None, + bidirectional_mask=None, + image_masks=None, + ): + """Applies token and positional embeddings to the input tokens.""" + cfg = self.config + + y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + + # Merge the image embeddings with the text embeddings for multimodal models + if image_embeddings is not None and cfg.use_multimodal: + if cfg.model_name in [ + "gemma3-4b", + "gemma3-12b", + 
"gemma3-27b", + "llama4-17b-16e", + "llama4-17b-128e", + "qwen3-omni-30b-a3b", + ]: + y = multimodal_utils.merge_mm_embeddings( + text_embeddings=y, + vision_embeddings=image_embeddings, + mask=bidirectional_mask, + image_masks=image_masks, + ) + # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed + else: + raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if cfg.use_untrainable_positional_embedding: + y = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y, decoder_positions) + + if cfg.trainable_position_size > 0: + y += embed_as_linen( + num_embeddings=cfg.trainable_position_size, + num_features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + mesh=self.mesh, + )(decoder_positions, model_mode=model_mode) + return y + + @nn.compact + def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): + """Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = NamedSharding( + self.mesh, + nn.logical_to_mesh_axes( + ( + "activation_batch", + "activation_length_no_exp", + "activation_embed", + ) + ), + ) + else: + norm_out_sharding = None + + y = self.get_norm_layer(num_features=y.shape[-1])( + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="decoder_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=cfg.parameter_memory_host_offload, + )(y, out_sharding=norm_out_sharding) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes((None, None, "activation_vocab"))) + else: + out_sharding = NamedSharding( + self.mesh, + nn.logical_to_mesh_axes( + ( + "activation_embed_and_logits_batch", + "activation_length_no_exp", + "activation_vocab", + ) + ), + ) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for logit transform. + if isinstance(shared_embedding, nnx.Module): + embedding_table = shared_embedding.embedding.value + else: + embedding_table = shared_embedding.variables["params"]["embedding"] + if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): + embedding_table = embedding_table.unbox() + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + # Correctly normalize pre-softmax logits for this shared case. 
+ logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = logits / cfg.final_logits_soft_cap + logits = jnp.tanh(logits) * cfg.final_logits_soft_cap + else: + logits = linears.dense_general( + inputs_shape=y.shape, + out_features_shape=cfg.vocab_size, + weight_dtype=cfg.weight_dtype, + dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability + kernel_axes=("embed", "vocab"), + shard_mode=cfg.shard_mode, + name="logits_dense", + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=cfg.parameter_memory_host_offload, + )( + y, + out_sharding=out_sharding, + ) # We do not quantize the logits matmul. + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + + return logits + + @nn.compact + def __call__( + self, + shared_embedding: nn.Module | nnx.Module, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + ): + cfg = self.config + mesh = self.mesh + assert decoder_input_tokens.ndim == 2 # [batch, len] + + # [batch, length] -> [batch, length, emb_dim] + y = self._apply_embedding( + shared_embedding, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + ) + + policy = self.get_remat_policy() + RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) + # scan does not support kwargs in layer call, passing broadcast_args as positional arg + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + if cfg.using_pipeline_parallelism: + if cfg.pipeline_fsdp_ag_once: + partition_spec = self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + else: + partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." + dense_layer = RemattedBlockLayers[0] + moe_layer = RemattedBlockLayers[1] + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) + # We chose not to pipeline the dense layers, only sparse for SPMD. 
+ with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + y, _ = self.scan_decoder_layers( + cfg, + dense_layer, + cfg.first_num_dense_layers, + "dense_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + )(y, *broadcast_args) + if num_moe_layers_outside_pp > 0: + y, _ = self.scan_decoder_layers( + cfg, + moe_layer, + num_moe_layers_outside_pp, + "moe_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + )(y, *broadcast_args) + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + else: # Not DeepSeek + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers + if remaining_layers > 0: + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) + with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + y, _ = self.scan_decoder_layers( + cfg, + RemattedBlockLayers[0], + remaining_layers, + "layers_outside_pipeline", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + )(y, *broadcast_args) + else: + if cfg.scan_layers: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." + layer_call_kwargs = { + "page_state": page_state, + "previous_chunk": previous_chunk, + "slot": slot, + } + dense_layer = RemattedBlockLayers[0] + dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) + y, _ = self.scan_decoder_layers( + cfg, + dense_layer, + cfg.first_num_dense_layers, + "dense_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + )(y, *broadcast_args) + moe_layer = RemattedBlockLayers[1] + moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, _ = self.scan_decoder_layers( + cfg, + moe_layer, + num_moe_layers, + "moe_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + )(y, *broadcast_args) + elif cfg.decoder_block == DecoderBlockType.GEMMA3: + y = self._apply_gemma3_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + RemattedBlockLayer = RemattedBlockLayers[0] + scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) + layer_kwargs = {} + if cfg.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, + } + y, _ = self.scan_decoder_layers( + cfg, + RemattedBlockLayer, + scan_length, + "layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=model_mode, + **layer_kwargs, + )(y, *broadcast_args) + else: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
+ dense_layer = RemattedBlockLayers[0] + moe_layer = RemattedBlockLayers[1] + + layers = [dense_layer, moe_layer] + layer_prefixes = ["dense_layers", "moe_layers"] + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] + # Iterate over the two layer groups (dense and MoE) and apply layer transformation + for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): + for index in range(num_layers): + kv_cache = kv_caches[index] if kv_caches is not None else None + y, kv_cache = layer( + config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode + )( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + if kv_caches is not None and kv_cache is not None: + kv_caches[index] = kv_cache + else: + for lyr in range(cfg.num_decoder_layers): + RemattedBlockLayer = RemattedBlockLayers[0] + layer_kwargs = {} + layer_call_kwargs = {} + if cfg.decoder_block == DecoderBlockType.GEMMA3: + # Gemma3 uses both global and sliding window attention depending on the layer index. + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} + if cfg.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + } + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + layer_kwargs = {"layer_idx": lyr} + if cfg.decoder_block == DecoderBlockType.GPT_OSS: + layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} + layer = RemattedBlockLayer( + config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs + ) + kv_cache = kv_caches[lyr] if kv_caches is not None else None + y, kv_cache = layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + **layer_call_kwargs, + ) + if kv_caches is not None and kv_cache is not None: + kv_caches[lyr] = kv_cache + + assert isinstance(y, jax.Array) + + # After the final transformer layer, `y` holds the raw, un-normalized hidden state. + hidden_state = y + + # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # Instead, we keep track on the hidden states, which has smaller size compared to full logits + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + self.sow("intermediates", "hidden_states", hidden_state) + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + # The API of the Decoder is now a tuple, providing both the main output + # and the raw hidden state needed for auxiliary tasks. 
+ return logits, hidden_state, kv_caches + + def _apply_gemma3_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ): + """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" + + cfg = self.config + mesh = self.mesh + + # Define the repeating pattern length and calculate how many full blocks to scan + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + + policy = self.get_remat_policy() + RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlockToLinen], policy)[0] + + layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} + layer_kwargs = {"num_of_layers": attention_pattern_length} + + # Apply the main scan over the full blocks + if scan_length > 0: + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + y, _ = self.scan_decoder_layers( + cfg, + RemattedGemma3Block, + scan_length, + "layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + model_mode=self.model_mode, + **layer_kwargs, + )(y, *broadcast_args, **layer_call_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + layer = RemattedGemma3Block( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs + ) # pytype: disable=wrong-keyword-args + y, _ = layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + **layer_call_kwargs, + ) + return y diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index c7284fb22..20c9dbe51 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -15,7 +15,7 @@ """ Pipeline layer wrapping a decoder layer(s). Supports circular pipelining """ import functools -from typing import Any +from typing import Any, Callable import numpy as np @@ -26,13 +26,16 @@ from flax.core import meta from flax import linen as nn +from flax import nnx from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT from MaxText.sharding import all_gather_over_fsdp +from MaxText import max_logging +from MaxText.layers import nnx_wrappers -class Pipeline(nn.Module): - """Module that implements pipelining across stages. +class Pipeline(nnx.Module): + """NNX Module that implements pipelining across stages. This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. This will produce a pipeline pattern if the stage dimension is sharded. @@ -42,18 +45,34 @@ class Pipeline(nn.Module): Attributes: config: Importantly contains num_pipeline_microbatches, num_pipeline_repeats. - layers: A module instance that each stage can execute. It can either be a single layer such as a + layers: A callable (NNX class or Linen class) that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance or scanned/looped set of decoder layers to execute multiple layers per stage. mesh: The device mesh of the system. 
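+    rngs: Optional NNX RNG state, supplied when the pipeline is constructed
+      through an NNX-to-Linen wrapper.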
remat_policy: Remat policy to use for the loop iterations """ - config: Config - layers: nn.Module # The name of this property (layers) is reflected in the state pytree and thus also checkpoints. - mesh: Mesh - remat_policy: Any = None + def __init__( + self, + layers: Callable | type, + config: Config, + mesh: Mesh, + rngs: nnx.Rngs = None, + remat_policy: Any = None, + ): + """Initialize Pipeline with NNX or Linen decoder layers. + + Args: + layers: Either an NNX class (type) or Linen class (type) to instantiate for each stage + config: Model configuration + mesh: Device mesh for sharding + rngs: Optional NNX RNG state (passed by ToLinen wrapper) + remat_policy: Remat policy for loop iterations + """ + self.config = config + self.mesh = mesh + self.rngs = rngs + self.remat_policy = remat_policy - def setup(self): # pylint: disable=missing-function-docstring self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches @@ -68,6 +87,29 @@ def setup(self): # pylint: disable=missing-function-docstring self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" + # Detect if layers is a Linen class/instance or NNX class + self._is_linen = (isinstance(layers, type) and issubclass(layers, nn.Module)) or isinstance(layers, nn.Module) + + if self._is_linen: + if isinstance(layers, nn.Module): + self.layers = layers + else: + self.layers = layers(config=config, mesh=mesh, model_mode=MODEL_MODE_TRAIN) + self._linen_variables = None + else: + # Create num_stages independent NNX instances, stored as attributes for + # NNX pytree tracking (not as Python lists). 
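+      # Each stage gets its own nnx.Rngs stream seeded with the stage index, so
+      # per-stage parameters are initialized from distinct RNG streams.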
+ for s in range(self.num_stages): + stage_rngs = nnx.Rngs(s) + instance = layers( + config=config, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=stage_rngs, + quant=None, + ) + setattr(self, f'stage_{s}', instance) + def need_circ_storage(self): return ( self.config.num_pipeline_repeats > 1 @@ -75,83 +117,56 @@ def need_circ_storage(self): ) def iterations_to_complete_first_microbatch_one_repeat(self): - # Return the number of iterations it takes for microbatch 0 to finish a repeat + """Returns iterations for microbatch 0 to complete one repeat.""" return self.forwarding_delay * (self.num_stages - 1) def iterations_to_complete_first_microbatch(self): - # Return the number of iterations it takes for microbatch 0 to finish the last stage of the last repeat + """Returns iterations for microbatch 0 to complete all repeats.""" return ( self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + self.iterations_to_complete_first_microbatch_one_repeat() ) def init_states(self, inputs): - """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover - Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] - - Returns a dictionary with properties - shift: zeros shape [num_stages, micro_size, sequence, embed] - prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None - state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] - circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None - circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None - loop_iteration: scalar set initially to 0. - """ + """Initialize pipeline loop state buffers. - # Shift is used to rotate the output of each pipeline into the input of the next - # shift has shape [num_stages, micro_size, sequence, embed] - shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + Assumes inputs are reshaped to [num_microbatches, micro_batch_size, sequence, embed]. 
- shift = nn.with_logical_constraint( + Returns: + Dictionary containing: + - shift: Buffer for rotating outputs [num_stages, micro_size, sequence, embed] + - prev_outputs: Same shape as shift (only used with pipeline_delay_activation_forwarding) + - state_io: Input/output buffer [num_stages, microbatches/stages, micro_size, sequence, embed] + - circ_storage: Circular storage buffer (only when num_microbatches > num_stages) + - circ_storage_mover: One-iteration delay buffer for circ_storage + - loop_iteration: Iteration counter (starts at 0) + """ + shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + shift = self._with_logical_constraint( shift, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, ) - # Prev outputs has the same shape of the output (and shift) if self.config.pipeline_delay_activation_forwarding: prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - prev_outputs = nn.with_logical_constraint( + prev_outputs = self._with_logical_constraint( prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, ) else: prev_outputs = None - # state_io (state input output) at first holds all of the input batches, but also will hold the outputs - # as the pipeline runs/finishes - # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) - # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. - state_io = nn.with_logical_constraint( + state_io = self._with_logical_constraint( state_io, ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, ) - # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only - # needed when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without - # additional storage. - # circ_storage has shape [num_stages, microbatches, micro_size, sequence, embed]. - # Note that this shape is a factor of num_stages larger than necessary - each stage holds the global batch, but only - # stage 0 holds the real activations (since it will use them), the rest hold dummy ones. This amount of storage - # [global_batch, sequence, embed] is fine as long as there is some amount of additional sharding axes, e.g. FSDP, - # TP, DP (e.g. there are many devices that shard stage 0) - # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101). 
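+    # circ_storage buffers the final stage's outputs between repeats; it is only
+    # needed when num_microbatches > num_stages, otherwise the last stage feeds
+    # the first stage directly through shift.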
if self.use_circ_storage: circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) - else: - circ_storage = None - - # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration - # of delay circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] - if self.use_circ_storage: circ_storage_mover = shift else: + circ_storage = None circ_storage_mover = None init_loop_state = { @@ -164,55 +179,44 @@ def init_states(self, inputs): } return init_loop_state + def _with_logical_constraint(self, tensor, logical_axis_names): + """Applies logical sharding constraints to tensor.""" + return nn.with_logical_constraint( + tensor, + logical_axis_names, + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) + def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): - """ - Construct stages_in: the global array that is operated on for this iteration, shape same as - shift=[stages, micro_size, sequence, embed] - This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from - state_io or an old one from circ_storage - """ + """Constructs input array for all stages for this iteration. - # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) + Returns array of shape [stages, micro_size, sequence, embed] with rotated outputs + from previous iteration, except stage 0 which gets new input from state_io or circ_storage. + """ state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] if self.use_circ_storage: - # Setup potential input from circ_storage, which also has a rotating index for microbatch, - # size of num_microbatches circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches circular_stage_in = circ_storage[:, circ_storage_batch_idx] else: - # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage circular_stage_in = shift - # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left - # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. - # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) - # Note that first_stage_in may correspond to bubble computation during the last few iterations. - # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are - # thus discarded / not returned. - # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The - # state_io will not contain bubble results at the end of the last iteration. 
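# A small sketch of the stage-0 selection performed next: broadcasted_iota marks the stage
# axis, so row 0 takes the freshly gathered microbatch while the other rows keep the
# rotated shift buffer (assumed num_stages=4 and a toy feature dimension of 3).
import jax
import jax.numpy as jnp

num_stages = 4
shift = jnp.full((num_stages, 3), 7.0)        # rotated outputs of the previous iteration
first_stage_in = jnp.zeros((num_stages, 3))   # new microbatch; only row 0 is actually consumed

stage_index = jax.lax.broadcasted_iota("int32", shift.shape, 0)
stages_in = jnp.where(stage_index == 0, first_stage_in, shift)
print(stages_in[:, 0])  # [0. 7. 7. 7.]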
-
    def select_state_or_input(first_stage_in, shift):
-      # Selects input for stage 0, shift for other stages
      return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift)

-    # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output)
    stages_in = select_state_or_input(first_stage_in, shift)
-    stages_in = nn.with_logical_constraint(
+    stages_in = self._with_logical_constraint(
        stages_in,
        ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"),
-        rules=self.config.logical_axis_rules,
-        mesh=self.mesh,
    )
    return stages_in

  def shard_dim_by_stages(self, x, dim: int):
-    # Shards a dimension by stages. Currently, the sharding of other dimensions are left up the compiler, alternatively
-    # we may want to copy over the sharding from the other input axes.
+    """Shards the specified dimension by stage."""
    dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim
    dims_mapping[dim] = "stage"
    dims_mapping = tuple(dims_mapping)
@@ -220,28 +224,24 @@ def shard_dim_by_stages(self, x, dim: int):
    return jax.lax.with_sharding_constraint(x, sharding)

  def get_microbatch_and_repeat_ids(self, loop_iteration):
-    """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
-    non-circular"""
-    # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages
+    """Gets microbatch and repeat IDs for all stages at this iteration."""
    microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
    microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
    repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
    return microbatch_ids, repeat_ids

  def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights):
-    """Use vmap to implement a sharded parallel gather.
-    Parallel gather means each stage has its own weights, and gets one slice from it.
+    """Sharded parallel gather where each stage has its own weights and gets one slice.
+
    Args:
-      weights: Per-stage data to be gathered from.
-      repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages.
-      repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not
-        have this dimension.
-      stages_dim_in_weights: The dimension in weights that represents parallel stages.
+      weights: Per-stage data to gather from.
+      repeat_ids: Integer tensor of shape [num_stages] with repeat indices per stage.
+      repeat_dim_in_weights: Dimension where repeat_ids are applied (removed in output).
+      stages_dim_in_weights: Dimension representing parallel stages.
+
    Returns:
-      The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights
-      removed.
+      Per-stage gathered values with repeat_dim_in_weights removed.
    """
-
    def _gather_one(x, repeat_id):
      return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights)

@@ -255,21 +255,16 @@ def _gather_one(x, repeat_id):
    return stage_weights

  def vmap_gather(self, xs, ids, ids_dim):
-    """Use vmap to implement a stage-wise sharded gather.
-
-    The stages share the same input, but they have different offsets.
+    """Stage-wise sharded gather with shared input but different offsets per stage.

    Args:
-      xs: Data shared by all stages, to be gathered from.
-      ids: Integer tensor of shape [num_stages], the offsets of the stages.
-      ids_dim: The dimension in xs where ids are applied. In the output, this
-        dimension will be [num_stages], since each stage gets one slice.
+      xs: Data shared by all stages.
+      ids: Integer tensor of shape [num_stages] with offsets per stage.
+      ids_dim: Dimension where ids are applied (output has [num_stages] size here).

    Returns:
-      The per-stage gathered values. The shape is xs.shape but with ids_dim size
-      replaced with [num_stages].
+      Per-stage gathered values with ids_dim size replaced with [num_stages].
    """
-
    def _gather_one(x, i):
      return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim)

@@ -278,18 +273,11 @@ def _gather_one(x, i):
    return self.shard_dim_by_stages(outs, 0)

  def get_new_loop_state(self, output, loop_state):
-    """
-    Update the various buffers given the output of the most recent iteration
-    * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output)
-      * Pushing inputs up from top of state_io into first stage of shift
-      * Pulling outputs up from last stage of shift into bottom of state_io
-    * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to
-      right/down
-    * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage
-    * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration
-    * prev_outputs: is set to the current output
-    """
+    """Updates all pipeline buffers after one iteration.
+
+    Updates shift, state_io, circ_storage, circ_storage_mover, and prev_outputs
+    to advance the pipeline by one step.
+    """
    old_state_io = loop_state["state_io"]
    old_circ_storage = loop_state["circ_storage"]
    old_circ_storage_mover = loop_state["circ_storage_mover"]
@@ -297,25 +285,19 @@ def get_new_loop_state(self, output, loop_state):
    old_prev_outputs = loop_state["prev_outputs"]

    def _rotate_right(arr):
-      # Use lax.slice to avoid generating a gather.
      last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
      except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
      return jnp.concatenate([last, except_last], axis=0)

    def _shift_right(arr):
      padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
-      # Use lax.slice to guarantee the gradient is a pad.
      return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)

-    # Shift either rotates or shifts depending on if the last stage immediately must send to first or not
-    # For non-circular pipelines, the last stage does not need to send to first
-    # For circular pipelines with #micro = #stages, last stage immediately sends to first
-    # For circular pipelines with #micro > stages (circ_storage), last stage sends to circ storage
    def _update_shift(output_in):
      if self.config.num_pipeline_repeats == 1 or self.use_circ_storage:
-        return _shift_right(output_in)  # last stage does not have to send to first immediately
+        return _shift_right(output_in)
      else:
-        return _rotate_right(output_in)  # last stage must immediately send to first
+        return _rotate_right(output_in)

    if self.config.pipeline_delay_activation_forwarding:
      new_shift = _update_shift(old_prev_outputs)
@@ -325,17 +307,12 @@ def _update_shift(output_in):
      new_prev_outputs = None

    if self.use_circ_storage:
-      # Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index.
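# A rotate-vs-shift sketch, assuming num_stages=4 and toy one-element activations, of the
# two buffer updates defined above: circular pipelines rotate the last stage's output back
# into stage 0, while non-circular (or circ_storage-backed) pipelines shift it out and a
# zero row enters at stage 0.
import jax
import jax.numpy as jnp

num_stages = 4
arr = jnp.arange(num_stages, dtype=jnp.float32).reshape(num_stages, 1)

def _rotate_right(a):
  last = jax.lax.slice_in_dim(a, num_stages - 1, num_stages, axis=0)
  except_last = jax.lax.slice_in_dim(a, 0, num_stages - 1, axis=0)
  return jnp.concatenate([last, except_last], axis=0)

def _shift_right(a):
  padding = [[1, 0]] + [[0, 0]] * (a.ndim - 1)
  return jax.lax.slice(jnp.pad(a, padding), [0] * a.ndim, a.shape)

print(_rotate_right(arr).ravel())  # [3. 0. 1. 2.]
print(_shift_right(arr).ravel())   # [0. 0. 1. 2.]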
- # circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped - # compute/async transfers def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): rotated = _rotate_right(circ_storage_mover_in) rotated = jnp.expand_dims(rotated, 1) - # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 offset = ( loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 - ) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the - # previous output - using circ_storage_mover before it is updated + ) % self.config.num_pipeline_microbatches return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) @@ -344,13 +321,10 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): new_circ_storage = None new_circ_storage_mover = None - # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the - # last stage output stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] def _update_state_io(state_in, stream_slice, output): - # Shift the current slice to the left, then fill the last stage with the final output. padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) stream_slice = jnp.where( @@ -372,119 +346,132 @@ def _update_state_io(state_in, stream_slice, output): return new_loop_state def permute_output_micro_per_stage_dim(self, output): - # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to - # state_io - it will land on a different index of state_io depending on the number of iterations. + """Permutes output to correct microbatch ordering after pipeline completion.""" microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage permutation = ( np.arange(self.microbatches_per_stage) + microbatch_0_idx - ) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear - # in idx 1, etc + ) % self.microbatches_per_stage output = output[:, permutation] return output - def get_current_stage_weights(self, pipeline_weights, loop_iteration): - """ - Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. - {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. - For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However - for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
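# A worked example of the permutation above, assuming microbatch 0 finishes and is written
# to state_io on iteration 3 (e.g. 4 stages, a single repeat, forwarding_delay=1) with
# microbatches_per_stage == 4: it lands in column 3 % 4 == 3, so the columns are rolled
# until it comes back out at index 0.
import numpy as np

microbatches_per_stage = 4
microbatch_0_idx = 3 % microbatches_per_stage                                          # 3
permutation = (np.arange(microbatches_per_stage) + microbatch_0_idx) % microbatches_per_stage
print(permutation)  # [3 0 1 2]; output[:, permutation] puts microbatch 0 first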
- """ - if self.config.num_pipeline_repeats > 1: - return self.get_current_repeat_from_stages(pipeline_weights, loop_iteration) - else: - return pipeline_weights - - def get_current_repeat_from_stages(self, weights, loop_iteration): - """get current repeat from stages""" - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - - def gather_weights_for_stages_in(weights): - return jax.tree.map( - functools.partial( - self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 - ), - weights, - ) + def _initialize_linen_parameters(self, sample_input, sample_seg_ids, sample_positions, deterministic, model_mode): + """Initialize Linen module parameters for all stages.""" + if self._linen_variables is not None: + return - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis( - weights, 0, circular_metadata_params - ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular - # entry per stage. - weights = gather_weights_for_stages_in(weights) - return weights - - def get_vmap_func_for_init(self): - """This vmap func is used to initialize the weights only on init.""" - - def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): - """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" - return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, None, None), - spmd_axis_name="stage", - variable_axes={"params": 0, "_overwrite_with_gradient": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, + linen_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)} + + base_params = self.layers.init( + linen_rngs, + sample_input, + sample_seg_ids, + sample_positions, + deterministic, + model_mode, ) - return vmap_func - def get_main_vmap_func_for_iterations(self): - """ - Returns main stage function vmapped by number of stages. - This becomes a vmap over a single layer instance if body_instance is a single layer, - else a set of layers if body_instance is a set of layers. 
- """ + stage_params = {} + for stage_idx in range(self.num_stages): + stage_params[f'stage_{stage_idx}'] = jax.tree_util.tree_map(lambda x: x, base_params) + + self._linen_variables = {'params': stage_params} + + def _run_stages_linen( + self, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ): + """Run stages using Linen module with manual vmap.""" + stage_params_list = [self._linen_variables['params'][f'stage_{i}'] for i in range(self.num_stages)] + stacked_params = jax.tree_util.tree_map( + lambda *xs: jnp.stack(xs, axis=0), + *stage_params_list + ) - def func_to_vmap( - body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode - ): - """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, + def apply_stage(stage_params, stage_input, stage_seg_ids, stage_pos): + output = self.layers.apply( + stage_params, + stage_input, + stage_seg_ids, + stage_pos, + deterministic, + model_mode, ) - return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, 0, None, None), - spmd_axis_name="stage", - variable_axes={"params": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, + if isinstance(output, tuple): + return output[0] + return output + + if stages_segment_ids is None: + vmapped_apply = jax.vmap( + lambda p, i, pos: apply_stage(p, i, None, pos), + in_axes=(0, 0, 0), + out_axes=0 + ) + stages_outputs = vmapped_apply(stacked_params, stages_inputs, stages_positions) + else: + vmapped_apply = jax.vmap(apply_stage, in_axes=(0, 0, 0, 0), out_axes=0) + stages_outputs = vmapped_apply(stacked_params, stages_inputs, stages_segment_ids, stages_positions) + + return stages_outputs + + def _run_stages_vmapped( + self, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ): + """Run all stages in parallel using JAX vmap over NNX instances.""" + stage_0 = getattr(self, 'stage_0') + graphdef, state_0 = nnx.split(stage_0) + + states = [state_0] + for s in range(1, self.num_stages): + instance = getattr(self, f'stage_{s}') + _, state_s = nnx.split(instance) + states.append(state_s) + + stacked_state = jax.tree_util.tree_map( + lambda *xs: jnp.stack(xs, axis=0), + *states ) - return vmap_func + + def call_stage(state, stage_input, stage_seg_ids, stage_pos): + module = nnx.merge(graphdef, state) + output = module(stage_input, stage_seg_ids, stage_pos, deterministic, model_mode) + if isinstance(output, tuple): + return output[0] + return output + + if stages_segment_ids is None: + def call_stage_no_seg(state, stage_input, stage_pos): + module = nnx.merge(graphdef, state) + output = module(stage_input, None, stage_pos, deterministic, model_mode) + if isinstance(output, tuple): + return output[0] + return output + + vmapped_call = jax.vmap(call_stage_no_seg, in_axes=(0, 0, 0), out_axes=0) + stages_outputs = vmapped_call(stacked_state, stages_inputs, stages_positions) + else: + vmapped_call = jax.vmap(call_stage, 
in_axes=(0, 0, 0, 0), out_axes=0) + stages_outputs = vmapped_call(stacked_state, stages_inputs, stages_segment_ids, stages_positions) + + return stages_outputs def run_one_iteration( - self, loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance + self, + loop_state, + positions, + segment_ids, + deterministic, + model_mode, ): - """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, - and update the loop state.""" + """Run one loop iteration: get inputs, execute stages, update state.""" state_io = loop_state["state_io"] shift = loop_state["shift"] circ_storage = loop_state["circ_storage"] @@ -493,67 +480,32 @@ def run_one_iteration( microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) - # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire - # buffer. stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None - vmap_func = self.get_main_vmap_func_for_iterations() - - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - - def prepare_vars_for_main_vmap(weights): - def gather_weights_for_stages_in(weights): - return jax.tree.map( - functools.partial( - self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 - ), - weights, - ) - - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis( - weights, 0, circular_metadata_params - ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one - # circular entry per stage. - weights = gather_weights_for_stages_in(weights) - return weights - - vmap_func = nn.map_variables( - vmap_func, - mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], - mutable=True, - trans_in_fn=prepare_vars_for_main_vmap, + if self._is_linen: + stages_output = self._run_stages_linen( + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ) + else: + stages_output = self._run_stages_vmapped( + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, ) - - stage_weights = self.get_current_stage_weights(pipeline_weights, loop_iteration) - stages_output = vmap_func( - decoder_layer_instance, - stage_weights, - stages_inputs, - stages_segment_ids, - stages_positions, - deterministic, - model_mode, - ) - if self.config.scan_layers: - stages_output = stages_output[0] new_state = self.get_new_loop_state(stages_output, loop_state) return new_state def get_pipeline_remat_policy(self): - """Returns the pipeline remat policy for this pipeline.""" - # We ensure that the decoder layer inputs are saved, although we leave it to a custom - # policy if they should be saved to device or offloaded. 
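# A minimal sketch, using flax.nnx's Linear as a stand-in stage, of the split/stack/merge
# pattern _run_stages_vmapped uses above: per-stage NNX modules are split into a shared
# graphdef plus per-stage states, the states are stacked on a leading stage axis, and
# jax.vmap re-merges one module per stage inside the mapped function.
import jax
import jax.numpy as jnp
from flax import nnx

num_stages, features = 4, 8
stages = [nnx.Linear(features, features, rngs=nnx.Rngs(s)) for s in range(num_stages)]

graphdef, state_0 = nnx.split(stages[0])
states = [state_0] + [nnx.split(m)[1] for m in stages[1:]]
stacked_state = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=0), *states)

def call_stage(state, x):
  # Rebuild one stage's module from the shared graphdef and its (unbatched) state slice.
  module = nnx.merge(graphdef, state)
  return module(x)

stage_inputs = jnp.ones((num_stages, 2, features))
outputs = jax.vmap(call_stage, in_axes=(0, 0))(stacked_state, stage_inputs)
print(outputs.shape)  # (4, 2, 8)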
+ """Returns the remat policy for pipeline iterations.""" if self.config.remat_policy == "custom": return self.remat_policy @@ -564,71 +516,6 @@ def get_pipeline_remat_policy(self): remat_policy = save_input_policy return remat_policy - def get_weight_sharding(self, *init_args): - """get weight sharding function for this pipeline.""" - # Returns a partition spec of all weights. Requires passing in arguments to init. - key = jax.random.PRNGKey(0) - keys = {"params": key, "dropout": key, "aqt": key} - weights = self.init(keys, *init_args) - - def get_partition_spec(pytree): - def _is_leaf(x): - return isinstance(x, nn.spmd.LogicallyPartitioned) - - def get_partition_spec_leaf(leaf): - return leaf.get_partition_spec() - - partition_spec_tree = jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) - return partition_spec_tree - - partition_spec_with_extra_layer = get_partition_spec(weights) - partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} - return partition_spec - - def get_physical_spec_no_fsdp(self, full_logical): - """ - Get physical spec without fsdp. - - TODO: Remove the expert sharding on attention weights as well, since those act like fsdp. - - Args: - full_logical: original logical partition specs of all weights - - Returns: - Modified physical spec with "fsdp" and "fsdp_transpose" removed - """ - - def remove_fsdp_sharding(sharding_tree): - def _remove_fsdp_from_partition_spec(named_sharding): - if isinstance(named_sharding, jax.sharding.NamedSharding): - new_spec = [] - for axis in named_sharding.spec: - if axis is None: - new_spec.append(None) - elif isinstance(axis, str): - if axis not in ("fsdp", "fsdp_transpose"): - new_spec.append(axis) - else: - new_spec.append(None) - elif isinstance(axis, (list, tuple)): - new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_spec.append(tuple(new_axis)) - else: - raise ValueError(f"Unsupported axis type: {type(axis)}") - return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec)) - return named_sharding - - return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree) - - physical = nn.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules) - physical_no_fsdp = remove_fsdp_sharding(physical) - return physical_no_fsdp - - def all_gather_over_fsdp(self, sharding_info): - physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(sharding_info) - return jax.lax.with_sharding_constraint(self.layers.variables, physical_constraint_no_fsdp) - - @nn.compact def __call__( self, inputs: jnp.ndarray, @@ -636,13 +523,13 @@ def __call__( positions: jnp.ndarray, deterministic: bool, model_mode=MODEL_MODE_TRAIN, - partition_spec=None, # Pytree of sharding specifications of the weights (aka self.layers.variables) + partition_spec=None, ) -> jnp.ndarray: - """The main method that maps the series of decoder layer inputs to final layer outputs. - Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape - [global_batch], and internally this will be reshapped into microbatches. + """Maps decoder layer inputs to outputs using pipeline parallelism. + + Reshapes inputs into microbatches, runs pipeline iterations with bubble + handling, and returns outputs reshaped to original batch size. """ - # Reshape inputs of [global_batch, ...] to [microbatches, pipeline_microbatch_sizes, ...] 
inputs = inputs.reshape( ( self.config.num_pipeline_microbatches, @@ -651,165 +538,102 @@ def __call__( self.config.emb_dim, ) ) - example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module - # weights. + + if self._is_linen and self._linen_variables is None: + example_input = inputs[0] + example_seg_ids = segment_ids[0] if segment_ids is not None else None + example_pos = positions[0] if positions is not None else None + self._initialize_linen_parameters(example_input, example_seg_ids, example_pos, deterministic, model_mode) + ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) if positions is not None: - # AG positions positions = jax.lax.with_sharding_constraint(positions, ag_sharding) - positions = positions.reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) - example_position = jax.lax.broadcast(positions[0], [self.num_stages]) - position_idx = 0 - else: - example_position = None - position_idx = None + if segment_ids is not None: segment_ids = jax.lax.with_sharding_constraint(segment_ids, ag_sharding) segment_ids = segment_ids.reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) - example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) - segment_idx = 0 - else: - example_segmentation = None - segment_idx = None loop_state = self.init_states(inputs) - # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) - # compute to perform - # Each iteration is vmapped by num_stages, so the number of iterations should be - # num_micro * num_stages * repeats / num_stages = num_micro * repeats - # However due to the pipeline bubble some iterations process less than num_stages microbatches. It takes - # num_micro * repeat iterations for the last microbatch to start the final repeat, then an additional - # num_stages - 1 to finish the final repeat. - # Thus the total iterations is num_micro * repeat + num_stages - 1, & we may consider the num_stages - 1 as bubble. - # The bubble doubles when we use forwarding delay. bubble_iterations = self.forwarding_delay * (self.num_stages - 1) real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats total_iterations = real_iterations + bubble_iterations - if self.is_initializing(): - vmap_func = self.get_vmap_func_for_init() - - if self.config.num_pipeline_repeats > 1: - # To shard the weights on initialization for the circular pipeline we create weights of - # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis. - # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization. 
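# A small sketch of the reshape at the top of __call__ above: the global batch is split
# into microbatches before entering the pipeline loop (assumed global_batch=8,
# num_pipeline_microbatches=4, seq=16, emb=32).
import jax.numpy as jnp

global_batch, num_microbatches, seq, emb = 8, 4, 16, 32
micro_size = global_batch // num_microbatches
inputs = jnp.zeros((global_batch, seq, emb)).reshape((num_microbatches, micro_size, seq, emb))
print(inputs.shape)  # (4, 2, 16, 32)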
- vmap_func = nn.vmap( - vmap_func, - in_axes=(0, segment_idx, position_idx, None, None), - variable_axes={ - "params": 0, - "_overwrite_with_gradient": 0, - "non_trainable": 0, - "hyper_params": 0, - }, - split_rngs={"params": True, "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": True, - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - }, + if self.config.scan_pipeline_iterations: + def run_iteration_scannable(loop_state, xs): + return ( + self.run_one_iteration( + loop_state, positions, segment_ids, deterministic, model_mode + ), + None, ) - example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) - example_segmentation = ( - jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) - if example_segmentation is not None - else None + if self.config.set_remat_policy_on_pipeline_iterations: + run_iteration_scannable = jax.checkpoint( + run_iteration_scannable, + prevent_cse=False, + policy=self.get_pipeline_remat_policy(), ) - example_position = ( - jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) - if example_position is not None - else None - ) - # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for - # the full total_iterations. - stage_outputs = vmap_func( - self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode - ) - if self.config.scan_layers: - stage_outputs = stage_outputs[0] - - # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stages output - # which has shape [pipeline_microbatch_size, sequence, embed] - if self.config.num_pipeline_repeats > 1: - stage_outputs = stage_outputs[0] # Remove extra dimension created for the circular vmap - broadcasted_stage_outpus = jax.lax.broadcast( - stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] - ) - return jnp.reshape( - broadcasted_stage_outpus, - [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], - ) - - if self.config.pipeline_fsdp_ag_once: - all_pipeline_weights = all_gather_over_fsdp( - self.layers.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules - ) - else: - all_pipeline_weights = self.layers.variables - - def run_iteration_scannable(model, loop_state, xs): - # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we - # explicitly wrap the run_one_iteration in this method - the 1st argument model (`self`) is a nn.module instance. - return ( - model.run_one_iteration( - loop_state, all_pipeline_weights, positions, segment_ids, deterministic, model_mode, model.layers - ), - None, - ) - if self.config.set_remat_policy_on_pipeline_iterations: - run_iteration_scannable = nn.remat( - run_iteration_scannable, - prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan - policy=self.get_pipeline_remat_policy(), - ) - - # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized. - if self.config.scan_pipeline_iterations: - variable_carry = [] - variable_broadcast = [ - "params", - "_overwrite_with_gradient", - ] # All loop iterations need the weights for the full pipeline. 
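# A minimal sketch of the scan-over-iterations pattern being set up here: the per-iteration
# body is wrapped in jax.checkpoint with a save-these-names policy and driven by
# jax.lax.scan for a fixed number of iterations (a scalar carry stands in for the real
# loop_state dict).
import jax
import jax.numpy as jnp

total_iterations = 6

def run_one_iteration(carry):
  return carry + 1.0

def run_iteration_scannable(carry, xs):
  return run_one_iteration(carry), None

run_iteration_scannable = jax.checkpoint(
    run_iteration_scannable,
    prevent_cse=False,
    policy=jax.checkpoint_policies.save_only_these_names("iteration_input"),
)

carry, _ = jax.lax.scan(run_iteration_scannable, jnp.zeros(()), None, length=total_iterations)
print(carry)  # 6.0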
- if self.is_mutable_collection("non_trainable"): - variable_carry.append("non_trainable") - else: - variable_broadcast.append("non_trainable") - run_all_iterations_scanned = nn.scan( - run_iteration_scannable, - variable_axes={ - "summaries": 0, - "aux_loss": 0, - "intermediates": 0, - "hyper_params": 0, - }, - variable_broadcast=variable_broadcast, - variable_carry=variable_carry, - # Dropout/aqt keys will be split for each iteration. - split_rngs={"random": True}, - length=total_iterations, - ) - loop_state, _ = run_all_iterations_scanned(self, loop_state, None) + loop_state, _ = jax.lax.scan(run_iteration_scannable, loop_state, None, length=total_iterations) else: for _ in range(total_iterations): - loop_state, _ = run_iteration_scannable(self, loop_state, None) + loop_state = self.run_one_iteration( + loop_state, positions, segment_ids, deterministic, model_mode + ) - # The final output is located in the input/output array, however the output microbatches may be permuted relative to - # the input final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) - # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] final_output = jnp.reshape( final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim) ) return final_output + + +class PipelineToLinen(nnx_wrappers.ToLinen): + """Wrap NNX Pipeline as a Linen module. + + This allows the NNX Pipeline to be used within the Linen Decoder module. + """ + pass + + +def create_pipeline( + config: Config, + layers: Callable | type, + mesh: Mesh, + remat_policy: Any = None, + use_nnx: bool = True, +) -> PipelineToLinen: + """Factory function to create a Pipeline wrapped as a Linen module. + + Args: + config: Model configuration + layers: NNX or Linen decoder layer class to use for pipeline stages + mesh: Device mesh for sharding + remat_policy: Remat policy for loop iterations + use_nnx: Whether to use NNX pipeline (True) or Linen (False) + + Returns: + PipelineToLinen wrapper around the NNX Pipeline + """ + if not use_nnx: + raise ValueError("This implementation only supports NNX pipelines (use_nnx=True)") + + wrapped = PipelineToLinen( + Pipeline, + kwargs={ + 'layers': layers, + 'config': config, + 'mesh': mesh, + 'remat_policy': remat_policy, + } + ) + + return wrapped diff --git a/src/MaxText/layers/pipeline_linen.py b/src/MaxText/layers/pipeline_linen.py new file mode 100644 index 000000000..c7284fb22 --- /dev/null +++ b/src/MaxText/layers/pipeline_linen.py @@ -0,0 +1,815 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Pipeline layer wrapping a decoder layer(s). 
Supports circular pipelining """ + +import functools +from typing import Any + +import numpy as np + +from jax import numpy as jnp +from jax.sharding import Mesh +import jax +import jax.ad_checkpoint + +from flax.core import meta +from flax import linen as nn + +from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT +from MaxText.sharding import all_gather_over_fsdp + + +class Pipeline(nn.Module): + """Module that implements pipelining across stages. + + This module will loop over microbatches and execute the main body with a vmap for both the inputs and weights. + This will produce a pipeline pattern if the stage dimension is sharded. + + Supports circular pipelines, and multiple layers per stage are used when a module that executes multiple layers + is passed as the layers input. + + Attributes: + config: Importantly contains num_pipeline_microbatches, num_pipeline_repeats. + layers: A module instance that each stage can execute. It can either be a single layer such as a + LlamaDecoderLayer instance or scanned/looped set of decoder layers to execute multiple layers per stage. + mesh: The device mesh of the system. + remat_policy: Remat policy to use for the loop iterations + """ + + config: Config + layers: nn.Module # The name of this property (layers) is reflected in the state pytree and thus also checkpoints. + mesh: Mesh + remat_policy: Any = None + + def setup(self): # pylint: disable=missing-function-docstring + self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 + self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches + microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages + self.microbatches_per_stage = microbatches_per_stage + self.use_circ_storage = self.need_circ_storage() + + if self.config.expert_shard_attention_option == EP_AS_CONTEXT: + self.batch_axis_name = "activation_batch_no_exp" + self.seq_len_axis_name = "activation_length" + else: + self.batch_axis_name = "activation_batch" + self.seq_len_axis_name = "activation_length_no_exp" + + def need_circ_storage(self): + return ( + self.config.num_pipeline_repeats > 1 + and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay + ) + + def iterations_to_complete_first_microbatch_one_repeat(self): + # Return the number of iterations it takes for microbatch 0 to finish a repeat + return self.forwarding_delay * (self.num_stages - 1) + + def iterations_to_complete_first_microbatch(self): + # Return the number of iterations it takes for microbatch 0 to finish the last stage of the last repeat + return ( + self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + + self.iterations_to_complete_first_microbatch_one_repeat() + ) + + def init_states(self, inputs): + """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover + Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] + + Returns a dictionary with properties + shift: zeros shape [num_stages, micro_size, sequence, embed] + prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None + state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] + circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] 
when needed, else None + circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None + loop_iteration: scalar set initially to 0. + """ + + # Shift is used to rotate the output of each pipeline into the input of the next + # shift has shape [num_stages, micro_size, sequence, embed] + shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + + shift = nn.with_logical_constraint( + shift, + ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) + + # Prev outputs has the same shape of the output (and shift) + if self.config.pipeline_delay_activation_forwarding: + prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + prev_outputs = nn.with_logical_constraint( + prev_outputs, + ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) + else: + prev_outputs = None + + # state_io (state input output) at first holds all of the input batches, but also will hold the outputs + # as the pipeline runs/finishes + # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] + state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) + # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. + state_io = nn.with_logical_constraint( + state_io, + ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) + + # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only + # needed when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without + # additional storage. + # circ_storage has shape [num_stages, microbatches, micro_size, sequence, embed]. + # Note that this shape is a factor of num_stages larger than necessary - each stage holds the global batch, but only + # stage 0 holds the real activations (since it will use them), the rest hold dummy ones. This amount of storage + # [global_batch, sequence, embed] is fine as long as there is some amount of additional sharding axes, e.g. FSDP, + # TP, DP (e.g. there are many devices that shard stage 0) + # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101). 
+ if self.use_circ_storage: + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) + else: + circ_storage = None + + # circ_storage_mover is used to push the microbatches from the pipeline into circ_storage with one buffer iteration + # of delay circ_storage_mover shape is same as shift: [num_stages, micro_size, sequence, embed] + if self.use_circ_storage: + circ_storage_mover = shift + else: + circ_storage_mover = None + + init_loop_state = { + "state_io": state_io, + "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_storage_mover, + "loop_iteration": 0, + "prev_outputs": prev_outputs, + } + return init_loop_state + + def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): + """ + Construct stages_in: the global array that is operated on for this iteration, shape same as + shift=[stages, micro_size, sequence, embed] + This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from + state_io or an old one from circ_storage + """ + + # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) + state_io_batch_idx = loop_iteration % self.microbatches_per_stage + state_io_slice = state_io[:, state_io_batch_idx] + + if self.use_circ_storage: + # Setup potential input from circ_storage, which also has a rotating index for microbatch, + # size of num_microbatches + circ_storage_batch_idx = loop_iteration % self.config.num_pipeline_microbatches + circular_stage_in = circ_storage[:, circ_storage_batch_idx] + else: + # The last stage immediately flows into the first stage, use this rotated shift instead of circular storage + circular_stage_in = shift + + # For early loop iterations we grab a new input for stage 0 from the state_io. Once each microbatch has left + # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. + # from circ_storage). + first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) + + # Note that first_stage_in may correspond to bubble computation during the last few iterations. + # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are + # thus discarded / not returned. + # The final returned output is stored in the state_io, which has the appropriate total size of num_microbatches. The + # state_io will not contain bubble results at the end of the last iteration. + + def select_state_or_input(first_stage_in, shift): + # Selects input for stage 0, shift for other stages + return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) + + # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) + stages_in = select_state_or_input(first_stage_in, shift) + stages_in = nn.with_logical_constraint( + stages_in, + ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + mesh=self.mesh, + ) + return stages_in + + def shard_dim_by_stages(self, x, dim: int): + # Shards a dimension by stages. Currently, the sharding of other dimensions are left up the compiler, alternatively + # we may want to copy over the sharding from the other input axes. 
+ dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim + dims_mapping[dim] = "stage" + dims_mapping = tuple(dims_mapping) + sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) + return jax.lax.with_sharding_constraint(x, sharding) + + def get_microbatch_and_repeat_ids(self, loop_iteration): + """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and + non-circular""" + # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages + microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) + microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches + repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + + def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Use vmap to implement a sharded parallel gather. + Parallel gather means each stage has its own weights, and gets one slice from it. + Args: + weights: Per-stage data to be gathered from. + repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages. + repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not + have this dimension. + stages_dim_in_weights: The dimension in weights that represents parallel stages. + Returns: + The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights + removed. + """ + + def _gather_one(x, repeat_id): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) + + gathered_weights_stage_dim = 0 + repeat_ids = self.shard_dim_by_stages(repeat_ids, 0) + weights = self.shard_dim_by_stages(weights, stages_dim_in_weights) + stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( + weights, repeat_ids + ) + stage_weights = self.shard_dim_by_stages(stage_weights, gathered_weights_stage_dim) + return stage_weights + + def vmap_gather(self, xs, ids, ids_dim): + """Use vmap to implement a stage-wise sharded gather. + + The stages share the same input, but they have different offsets. + + Args: + xs: Data shared by all stages, to be gathered from. + ids: Integer tensor of shape [num_stages], the offsets of the stages. + ids_dim: The dimension in xs where ids are applied. In the output, this + dimension will be [num_stages], since each stage gets one slice. + + Returns: + The per-stage gathered values. The shape is xs.shape but with ids_dim size + replaced with [num_stages]. 
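# A worked example combining the two helpers above, with assumed sizes (num_stages=4,
# forwarding_delay=1, num_pipeline_microbatches=4, 2 repeats): get_microbatch_and_repeat_ids
# decides which repeat each stage is on, and the vmap_parallel_gather pattern picks that
# repeat's weights from a [num_repeats, num_stages, ...] layout.
import jax
import jax.numpy as jnp

num_stages, forwarding_delay = 4, 1
num_microbatches, num_repeats, embed = 4, 2, 3
loop_iteration = 5

processed = jnp.maximum(loop_iteration - forwarding_delay * jnp.arange(num_stages), 0)
repeat_ids = processed // num_microbatches          # [1 1 0 0]

weights = jnp.arange(num_repeats * num_stages * embed, dtype=jnp.float32).reshape(
    num_repeats, num_stages, embed
)

def _gather_one(x, repeat_id):
  # x: [num_repeats, embed] for a single stage; keep only its current repeat.
  return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, 0), 0)

stage_weights = jax.vmap(_gather_one, in_axes=(1, 0), out_axes=0)(weights, repeat_ids)
print(repeat_ids, stage_weights.shape)  # [1 1 0 0] (4, 3)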
+ """ + + def _gather_one(x, i): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) + + ids = self.shard_dim_by_stages(ids, 0) + outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) + return self.shard_dim_by_stages(outs, 0) + + def get_new_loop_state(self, output, loop_state): + """ + Update the various buffers given the output of the most recent iteration + * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) + * Pushing inputs up from top of state_io into first stage of shift + * Pulling outputs up from last stage of shift into bottom of state_io + * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to + right/down + * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage + * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration + * prev_outputs: is set to the current output + """ + + old_state_io = loop_state["state_io"] + old_circ_storage = loop_state["circ_storage"] + old_circ_storage_mover = loop_state["circ_storage_mover"] + loop_iteration = loop_state["loop_iteration"] + old_prev_outputs = loop_state["prev_outputs"] + + def _rotate_right(arr): + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) + + def _shift_right(arr): + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) + + # Shift either rotates or shifts depending on if the last stage immediately must send to first or not + # For non-circular pipelines, the last stage does not need to send to first + # For circular pipelines with #micro = #stages, last stage immediately sends to first + # For circular pipelines with #micro > stages (circ_storage), last stage sends to circ storage + def _update_shift(output_in): + if self.config.num_pipeline_repeats == 1 or self.use_circ_storage: + return _shift_right(output_in) # last stage does not have to send to first immediately + else: + return _rotate_right(output_in) # last stage must immediately send to first + + if self.config.pipeline_delay_activation_forwarding: + new_shift = _update_shift(old_prev_outputs) + new_prev_outputs = output + else: + new_shift = _update_shift(output) + new_prev_outputs = None + + if self.use_circ_storage: + # Insert the circ_storage_mover into new_circ_storage at a microbatch-rotating index. 
+ # circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped + # compute/async transfers + def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): + rotated = _rotate_right(circ_storage_mover_in) + rotated = jnp.expand_dims(rotated, 1) + # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 + offset = ( + loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 + ) % self.config.num_pipeline_microbatches # Note extra -1 b/c grabbing from the + # previous output - using circ_storage_mover before it is updated + return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) + + new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) + new_circ_storage_mover = output + else: + new_circ_storage = None + new_circ_storage_mover = None + + # Rotate stream_io left/up by 1 on rotating micro/stage index (stream_buf_idx), replacing the last/bottom with the + # last stage output + stream_buf_idx = loop_iteration % self.microbatches_per_stage + stream_slice = old_state_io[:, stream_buf_idx] + + def _update_state_io(state_in, stream_slice, output): + # Shift the current slice to the left, then fill the last stage with the final output. + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) + stream_slice = jnp.expand_dims(stream_slice, 1) + return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) + + new_state = _update_state_io(old_state_io, stream_slice, output) + + new_loop_state = { + "state_io": new_state, + "shift": new_shift, + "circ_storage": new_circ_storage, + "circ_storage_mover": new_circ_storage_mover, + "loop_iteration": loop_iteration + 1, + "prev_outputs": new_prev_outputs, + } + return new_loop_state + + def permute_output_micro_per_stage_dim(self, output): + # The first real output (microbatch 0) takes a certain amount of loop iterations to finish and be pushed to + # state_io - it will land on a different index of state_io depending on the number of iterations. + microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage + permutation = ( + np.arange(self.microbatches_per_stage) + microbatch_0_idx + ) % self.microbatches_per_stage # permute so the value in land_idx is moved into idx 0, and (land_idx + 1) appear + # in idx 1, etc + output = output[:, permutation] + return output + + def get_current_stage_weights(self, pipeline_weights, loop_iteration): + """ + Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. + {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. + For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However + for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
+ """ + if self.config.num_pipeline_repeats > 1: + return self.get_current_repeat_from_stages(pipeline_weights, loop_iteration) + else: + return pipeline_weights + + def get_current_repeat_from_stages(self, weights, loop_iteration): + """get current repeat from stages""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def gather_weights_for_stages_in(weights): + return jax.tree.map( + functools.partial( + self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ), + weights, + ) + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis( + weights, 0, circular_metadata_params + ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular + # entry per stage. + weights = gather_weights_for_stages_in(weights) + return weights + + def get_vmap_func_for_init(self): + """This vmap func is used to initialize the weights only on init.""" + + def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name="stage", + variable_axes={"params": 0, "_overwrite_with_gradient": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return vmap_func + + def get_main_vmap_func_for_iterations(self): + """ + Returns main stage function vmapped by number of stages. + This becomes a vmap over a single layer instance if body_instance is a single layer, + else a set of layers if body_instance is a set of layers. 
+ """ + + def func_to_vmap( + body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode + ): + """nn.vmap requires either a nn.module class or a function whose first argument is a nn.module instance.""" + weights = meta.remove_axis( + weights, + 0, + { + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + vmap_func = nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, 0, None, None), + spmd_axis_name="stage", + variable_axes={"params": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return vmap_func + + def run_one_iteration( + self, loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance + ): + """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, + and update the loop state.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire + # buffer. + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func_for_iterations() + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights): + def gather_weights_for_stages_in(weights): + return jax.tree.map( + functools.partial( + self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ), + weights, + ) + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis( + weights, 0, circular_metadata_params + ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one + # circular entry per stage. 
+        weights = gather_weights_for_stages_in(weights)
+        return weights
+
+      vmap_func = nn.map_variables(
+          vmap_func,
+          mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"],
+          mutable=True,
+          trans_in_fn=prepare_vars_for_main_vmap,
+      )
+
+    stage_weights = self.get_current_stage_weights(pipeline_weights, loop_iteration)
+    stages_output = vmap_func(
+        decoder_layer_instance,
+        stage_weights,
+        stages_inputs,
+        stages_segment_ids,
+        stages_positions,
+        deterministic,
+        model_mode,
+    )
+    if self.config.scan_layers:
+      stages_output = stages_output[0]
+
+    new_state = self.get_new_loop_state(stages_output, loop_state)
+    return new_state
+
+  def get_pipeline_remat_policy(self):
+    """Returns the remat policy for the pipeline."""
+    # We ensure that the decoder layer inputs are saved, unless a custom policy is used, in which case
+    # that policy decides whether they are saved to device or offloaded.
+    if self.config.remat_policy == "custom":
+      return self.remat_policy
+
+    save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
+    if self.remat_policy is not None:
+      remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
+    else:
+      remat_policy = save_input_policy
+    return remat_policy
+
+  def get_weight_sharding(self, *init_args):
+    """Returns a partition spec of all pipeline weights; requires passing in the arguments used for init."""
+    key = jax.random.PRNGKey(0)
+    keys = {"params": key, "dropout": key, "aqt": key}
+    weights = self.init(keys, *init_args)
+
+    def get_partition_spec(pytree):
+      def _is_leaf(x):
+        return isinstance(x, nn.spmd.LogicallyPartitioned)
+
+      def get_partition_spec_leaf(leaf):
+        return leaf.get_partition_spec()
+
+      partition_spec_tree = jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf)
+      return partition_spec_tree
+
+    partition_spec_with_extra_layer = get_partition_spec(weights)
+    partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]}
+    return partition_spec
+
+  def get_physical_spec_no_fsdp(self, full_logical):
+    """
+    Gets the physical sharding spec with the fsdp sharding removed.
+
+    TODO: Remove the expert sharding on attention weights as well, since those act like fsdp.
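+
+    For example, a weight with physical sharding ("fsdp", "tensor") is returned with sharding
+    (None, "tensor"); only the "fsdp" and "fsdp_transpose" axes are removed.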
+
+    Args:
+      full_logical: original logical partition specs of all weights
+
+    Returns:
+      Modified physical spec with "fsdp" and "fsdp_transpose" removed
+    """
+
+    def remove_fsdp_sharding(sharding_tree):
+      def _remove_fsdp_from_partition_spec(named_sharding):
+        if isinstance(named_sharding, jax.sharding.NamedSharding):
+          new_spec = []
+          for axis in named_sharding.spec:
+            if axis is None:
+              new_spec.append(None)
+            elif isinstance(axis, str):
+              if axis not in ("fsdp", "fsdp_transpose"):
+                new_spec.append(axis)
+              else:
+                new_spec.append(None)
+            elif isinstance(axis, (list, tuple)):
+              new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")]
+              new_spec.append(tuple(new_axis))
+            else:
+              raise ValueError(f"Unsupported axis type: {type(axis)}")
+          return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec))
+        return named_sharding
+
+      return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree)
+
+    physical = nn.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules)
+    physical_no_fsdp = remove_fsdp_sharding(physical)
+    return physical_no_fsdp
+
+  def all_gather_over_fsdp(self, sharding_info):
+    """All-gathers the pipeline weights by constraining them to a sharding with the fsdp axes removed."""
+    physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(sharding_info)
+    return jax.lax.with_sharding_constraint(self.layers.variables, physical_constraint_no_fsdp)
+
+  @nn.compact
+  def __call__(
+      self,
+      inputs: jnp.ndarray,
+      segment_ids: jnp.ndarray,
+      positions: jnp.ndarray,
+      deterministic: bool,
+      model_mode=MODEL_MODE_TRAIN,
+      partition_spec=None,  # Pytree of sharding specifications of the weights (aka self.layers.variables)
+  ) -> jnp.ndarray:
+    """The main method that maps the series of decoder layer inputs to final layer outputs.
+    Has the same signature as a single decoder layer and expects the same shapes, e.g. the inputs should have shape
+    [global_batch, ...]; internally they are reshaped into microbatches.
+    """
+    # Reshape inputs from [global_batch, ...] to [num_microbatches, pipeline_microbatch_size, ...]
+    inputs = inputs.reshape(
+        (
+            self.config.num_pipeline_microbatches,
+            self.pipeline_microbatch_size,
+            self.config.max_target_length,
+            self.config.emb_dim,
+        )
+    )
+    # Dummy inputs fed to initialize the module weights.
+    example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages])
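+    # Replicate positions and segment_ids across the mesh (all-gather over any data/fsdp sharding) so
+    # every stage can read any microbatch, then reshape them into microbatches like the activations.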
+    ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))
+    if positions is not None:
+      # AG positions
+      positions = jax.lax.with_sharding_constraint(positions, ag_sharding)
+
+      positions = positions.reshape(
+          (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
+      )
+      example_position = jax.lax.broadcast(positions[0], [self.num_stages])
+      position_idx = 0
+    else:
+      example_position = None
+      position_idx = None
+    if segment_ids is not None:
+      segment_ids = jax.lax.with_sharding_constraint(segment_ids, ag_sharding)
+      segment_ids = segment_ids.reshape(
+          (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
+      )
+      example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages])
+      segment_idx = 0
+    else:
+      example_segmentation = None
+      segment_idx = None
+
+    loop_state = self.init_states(inputs)
+
+    # Each microbatch should go through each stage (with repeats), so there is num_micro * (num_stages * repeats)
+    # compute to perform.
+    # Each iteration is vmapped by num_stages, so the number of iterations should be
+    # num_micro * num_stages * repeats / num_stages = num_micro * repeats.
+    # However, due to the pipeline bubble some iterations process fewer than num_stages microbatches. It takes
+    # num_micro * repeats iterations for the last microbatch to start the final repeat, then an additional
+    # num_stages - 1 iterations to finish the final repeat.
+    # Thus the total number of iterations is num_micro * repeats + num_stages - 1, and we may consider the extra
+    # num_stages - 1 iterations as bubble. The bubble doubles when we use forwarding delay.
+    bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
+    real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
+    total_iterations = real_iterations + bubble_iterations
+
+    if self.is_initializing():
+      vmap_func = self.get_vmap_func_for_init()
+
+      if self.config.num_pipeline_repeats > 1:
+        # To shard the weights on initialization for the circular pipeline we create weights of
+        # shape [num_repeat, num_stages, ...] (e.g. [num_repeat, num_stages, embed, mlp]) and shard the num_stages axis.
+        # We wrap the main stage vmap with a num_repeat vmap to generate this axis only for parameter initialization.
+        vmap_func = nn.vmap(
+            vmap_func,
+            in_axes=(0, segment_idx, position_idx, None, None),
+            variable_axes={
+                "params": 0,
+                "_overwrite_with_gradient": 0,
+                "non_trainable": 0,
+                "hyper_params": 0,
+            },
+            split_rngs={"params": True, "dropout": self.config.enable_dropout},
+            metadata_params={
+                nn.PARTITION_NAME: "circular_repeats",
+                "sub_weight_split_dims_mapping": (None,),
+                "is_initializing": True,
+                "x_times": self.config.num_pipeline_repeats,
+                "optimizer_dims_mapping": None,
+            },
+        )
+
+        example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats])
+        example_segmentation = (
+            jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats])
+            if example_segmentation is not None
+            else None
+        )
+        example_position = (
+            jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats])
+            if example_position is not None
+            else None
+        )
+      # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for
+      # the full total_iterations.
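+      # A single application of the vmapped stages over example_inputs (leading dim num_stages, or
+      # [num_pipeline_repeats, num_stages, ...] for the circular case) is enough to create all of the weights.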
+      stage_outputs = vmap_func(
+          self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode
+      )
+      if self.config.scan_layers:
+        stage_outputs = stage_outputs[0]
+
+      # We return something of the correct shape (global_batch, sequence, embed) by reshaping a single stage's output,
+      # which has shape [pipeline_microbatch_size, sequence, embed].
+      if self.config.num_pipeline_repeats > 1:
+        stage_outputs = stage_outputs[0]  # Remove the extra dimension created for the circular vmap.
+      broadcasted_stage_outputs = jax.lax.broadcast(
+          stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size]
+      )
+      return jnp.reshape(
+          broadcasted_stage_outputs,
+          [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim],
+      )
+
+    if self.config.pipeline_fsdp_ag_once:
+      all_pipeline_weights = all_gather_over_fsdp(
+          self.layers.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules
+      )
+    else:
+      all_pipeline_weights = self.layers.variables
+
+    def run_iteration_scannable(model, loop_state, xs):
+      # Flax transforms like nn.scan and nn.remat can only be applied to nn.Module classes or instances, so we
+      # explicitly wrap run_one_iteration in this function; the first argument, model (`self`), is an nn.Module instance.
+      return (
+          model.run_one_iteration(
+              loop_state, all_pipeline_weights, positions, segment_ids, deterministic, model_mode, model.layers
+          ),
+          None,
+      )
+
+    if self.config.set_remat_policy_on_pipeline_iterations:
+      run_iteration_scannable = nn.remat(
+          run_iteration_scannable,
+          prevent_cse=not self.config.scan_pipeline_iterations,  # prevent_cse not used with scan
+          policy=self.get_pipeline_remat_policy(),
+      )
+
+    # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
+    if self.config.scan_pipeline_iterations:
+      variable_carry = []
+      variable_broadcast = [
+          "params",
+          "_overwrite_with_gradient",
+      ]  # All loop iterations need the weights for the full pipeline.
+      if self.is_mutable_collection("non_trainable"):
+        variable_carry.append("non_trainable")
+      else:
+        variable_broadcast.append("non_trainable")
+      run_all_iterations_scanned = nn.scan(
+          run_iteration_scannable,
+          variable_axes={
+              "summaries": 0,
+              "aux_loss": 0,
+              "intermediates": 0,
+              "hyper_params": 0,
+          },
+          variable_broadcast=variable_broadcast,
+          variable_carry=variable_carry,
+          # Dropout/aqt keys will be split for each iteration.
+ split_rngs={"random": True}, + length=total_iterations, + ) + loop_state, _ = run_all_iterations_scanned(self, loop_state, None) + else: + for _ in range(total_iterations): + loop_state, _ = run_iteration_scannable(self, loop_state, None) + + # The final output is located in the input/output array, however the output microbatches may be permuted relative to + # the input + final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) + + # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] + final_output = jnp.reshape( + final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim) + ) + + return final_output diff --git a/tests/pipeline_parallelism_test.py b/tests/pipeline_parallelism_test.py index 43efb62ca..da078fc7e 100644 --- a/tests/pipeline_parallelism_test.py +++ b/tests/pipeline_parallelism_test.py @@ -40,7 +40,7 @@ def assert_same_output_and_grad(f1, f2, *inputs): - """check that the output and gradient are the same""" + """Checks that the output and gradient are the same.""" f1_value, f1_grad = jax.value_and_grad(f1)(*inputs) f2_value, f2_grad = jax.value_and_grad(f2)(*inputs) @@ -59,7 +59,7 @@ def pytree_ravel(pytree): class PipelineParallelismTest(unittest.TestCase): def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_class=None): - """check that the output and gradient are the same""" + """Checks that the output and gradient are the same.""" devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) model_mode = MODEL_MODE_TRAIN @@ -72,19 +72,17 @@ def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_cla single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode) def get_inputs(batch_size, sequence, features): - """Get random inputs, and random dummy targets - Returns - inputs: [batch_size, sequence, features] - targets: [batch_size, sequence, features] - positions: [batch_size, sequence] - segmentations: [batch_size, segmentation] + """Get random inputs and dummy targets for gradient testing. + + Returns: + inputs: [batch_size, sequence, features] + targets: [batch_size, sequence, features] + positions: [batch_size, sequence] + segmentations: [batch_size, segmentation] """ input_shape = [batch_size, sequence, features] inputs = jax.random.normal(jax.random.PRNGKey(2), input_shape, dtype=jnp.float32) - - # dummy targets same shape as inputs to use for a dummy loss function to check gradient correctness dummy_targets = jax.random.normal(jax.random.PRNGKey(3), input_shape, dtype=jnp.float32) - inputs_position = jnp.array([jnp.arange(sequence, dtype=jnp.int32) for _ in range(batch_size)], dtype=jnp.int32) inputs_segmentation = jnp.ones((batch_size, sequence), dtype=jnp.int32) return inputs, dummy_targets, inputs_position, inputs_segmentation @@ -93,20 +91,11 @@ def get_inputs(batch_size, sequence, features): config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim ) deterministic = True - # We use a simpler single matmul decoder layer for fast compilation in these tests. 
- rngs = nnx.Rngs(params=0) - single_pipeline_stage = simple_layer.SimpleDecoderLayerToLinen( - config=config, mesh=mesh, model_mode=model_mode, rngs=rngs - ) + single_pipeline_stage = simple_layer.SimpleDecoderLayer + my_pipeline = pipeline.Pipeline(config=config, layers=single_pipeline_stage, mesh=mesh) - init_pipeline_params = my_pipeline.init( - jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode - ) - partition_spec = my_pipeline.get_weight_sharding( - inputs, inputs_position, inputs_segmentation, deterministic, model_mode - ) + graphdef, params, rest_state = nnx.split(my_pipeline, nnx.Param, ...) - # Create a dummy scalar loss function so we may take the gradient wrt weights def pipeline_parallelism_dummy_loss_extra( params, inputs, @@ -115,47 +104,24 @@ def pipeline_parallelism_dummy_loss_extra( deterministic, model_mode, dummy_targets, - partition_spec=None, ): - outputs = my_pipeline.apply( - params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, partition_spec=partition_spec - ) + pipeline_instance = nnx.merge(graphdef, params, rest_state) + outputs = pipeline_instance(inputs, inputs_position, inputs_segmentation, deterministic, model_mode) loss = jnp.linalg.norm(outputs - dummy_targets) return loss - pipeline_parallelism_dummy_loss = functools.partial( - pipeline_parallelism_dummy_loss_extra, partition_spec=partition_spec - ) + pipeline_parallelism_dummy_loss = pipeline_parallelism_dummy_loss_extra def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): - def get_cur_layer_params(params, layer_idx): - def get_cur_layer_params_arr(leaf): - # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] - if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: - new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] - leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] - elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: - new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] - leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] - elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: - new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] - leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] - return leaf[layer_idx] - - return jax.tree.map(get_cur_layer_params_arr, params) - + pipeline_instance = nnx.merge(graphdef, params, rest_state) reg_layer_activations = inputs - for layer in range(config.num_decoder_layers): - cur_layer_params = get_cur_layer_params(params, layer) - cur_layer_params["params"] = cur_layer_params["params"]["layers"] - if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: - cur_layer_params["params"] = meta.remove_axis( - cur_layer_params["params"], 0, {nn.PARTITION_NAME: "circular_repeats"} + num_stages = config.ici_pipeline_parallelism * config.dcn_pipeline_parallelism + for repeat in range(config.num_pipeline_repeats): + for s in range(num_stages): + stage = getattr(pipeline_instance, f'stage_{s}') + reg_layer_activations, _ = stage( + reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode ) - cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"}) - reg_layer_activations, _ = single_pipeline_stage.apply( - 
cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode - ) return reg_layer_activations def regular_sequential_layers_dummy_loss( @@ -168,7 +134,7 @@ def regular_sequential_layers_dummy_loss( assert_same_output_and_grad( regular_sequential_layers_dummy_loss, pipeline_parallelism_dummy_loss, - init_pipeline_params, + params, inputs, inputs_segmentation, inputs_position, @@ -179,7 +145,7 @@ def regular_sequential_layers_dummy_loss( @pytest.mark.tpu_only def test_circular_minimum_microbatches_same_output_and_grad(self): - # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches + """Tests 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -196,7 +162,7 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): @pytest.mark.tpu_only def test_circular_extra_microbatches_same_output_and_grad(self): - # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + """Tests 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -213,7 +179,7 @@ def test_circular_extra_microbatches_same_output_and_grad(self): @pytest.mark.tpu_only def test_circular_deepseek_megablox_same_output_and_grad(self): - # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + """Tests 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches with DeepSeek.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -236,7 +202,7 @@ def test_circular_deepseek_megablox_same_output_and_grad(self): @pytest.mark.tpu_only def test_circular_ag_once(self): - # 2 stages, 8 microbatches, all gather once + """Tests 2 stages, 8 microbatches, all gather once.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -254,7 +220,7 @@ def test_circular_ag_once(self): @pytest.mark.tpu_only def test_non_circular_same_output_and_grad(self): - # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches + """Tests 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -271,7 +237,7 @@ def test_non_circular_same_output_and_grad(self): @pytest.mark.integration_test @pytest.mark.tpu_only def test_full_train_circular(self): - # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches + """Full train.py test with 4 stages, 32 layers (2 per stage, 4 repeats), 8 microbatches.""" train_main( [ None, @@ -296,13 +262,13 @@ def test_full_train_circular(self): "num_layers_per_pipeline_stage=2", "num_pipeline_microbatches=8", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. 
+ "scan_layers_per_stage=False", ] ) @pytest.mark.tpu_only def test_delay_activation_forwarding_same_output_and_grad(self): - # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + """Tests 4 stages with delayed activation forwarding, 8 layers (2 repeats), 8 microbatches.""" config = pyconfig.initialize( [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], enable_checkpointing=False, @@ -321,7 +287,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self): @pytest.mark.integration_test @pytest.mark.tpu_only def test_full_train_non_circular(self): - # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches + """Full train.py test with 4 stages, 32 layers (8 per stage), 8 microbatches.""" train_main( [ None, @@ -346,14 +312,14 @@ def test_full_train_non_circular(self): "num_layers_per_pipeline_stage=8", "num_pipeline_microbatches=8", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. + "scan_layers_per_stage=False", ] ) @pytest.mark.integration_test @pytest.mark.tpu_only def test_subset_layers(self): - # Run a full train.py call with 4 stages, 16 layers - 8 in pipeline, 8 ran outside of pipeline + """Full train.py test with 4 stages, 16 layers (8 in pipeline, 8 outside).""" train_main( [ None, @@ -380,14 +346,13 @@ def test_subset_layers(self): "pipeline_parallel_layers=8", "num_pipeline_microbatches=8", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", - "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. + "scan_layers_per_stage=False", ] ) @pytest.mark.integration_test def test_full_train_fp8(self): - # Run a full train.py call with fp8 quantization, which adds extra - # variable collections that need to be handled + """Full train.py test with fp8 quantization.""" train_main( [ None, @@ -418,8 +383,7 @@ def test_full_train_fp8(self): @pytest.mark.integration_test def test_full_train_nanoo_fp8(self): - # Run a full train.py call with NANOO fp8 quantization, which adds extra - # variable collections that need to be handled + """Full train.py test with NANOO fp8 quantization.""" train_main( [ None, diff --git a/tests/pipeline_parallelism_test_linen.py b/tests/pipeline_parallelism_test_linen.py new file mode 100644 index 000000000..43efb62ca --- /dev/null +++ b/tests/pipeline_parallelism_test_linen.py @@ -0,0 +1,453 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for pipeline parallelism.""" + +import functools +import os.path +import sys +import unittest + +import pytest + +import jax +from jax.sharding import Mesh +import jax.numpy as jnp + +from flax.core import meta +from flax import linen as nn +from flax import nnx + +from MaxText import maxtext_utils +from MaxText import pyconfig +from MaxText.common_types import MODEL_MODE_TRAIN +from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT +from MaxText.layers import pipeline +from MaxText.layers import simple_layer +from MaxText.train import main as train_main +from MaxText.layers import deepseek + + +def assert_same_output_and_grad(f1, f2, *inputs): + """check that the output and gradient are the same""" + f1_value, f1_grad = jax.value_and_grad(f1)(*inputs) + f2_value, f2_grad = jax.value_and_grad(f2)(*inputs) + + def pytree_ravel(pytree): + ravelled_tree = jax.tree.map(jnp.ravel, pytree) + ravelled_leaves, _ = jax.tree_util.tree_flatten(ravelled_tree) + return jnp.concatenate(ravelled_leaves) + + f1_grad = pytree_ravel(f1_grad) + f2_grad = pytree_ravel(f2_grad) + + assert jax.numpy.allclose(f1_value, f2_value, rtol=1e-2, equal_nan=False) + assert jax.numpy.allclose(f1_grad, f2_grad, rtol=1e-1, equal_nan=False) + + +class PipelineParallelismTest(unittest.TestCase): + + def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_class=None): + """check that the output and gradient are the same""" + devices_array = maxtext_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + model_mode = MODEL_MODE_TRAIN + if single_pipeline_stage_class is None: + rngs = nnx.Rngs(params=0) + single_pipeline_stage = simple_layer.SimpleDecoderLayerToLinen( + config=config, mesh=mesh, model_mode=model_mode, rngs=rngs + ) + else: + single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode) + + def get_inputs(batch_size, sequence, features): + """Get random inputs, and random dummy targets + Returns + inputs: [batch_size, sequence, features] + targets: [batch_size, sequence, features] + positions: [batch_size, sequence] + segmentations: [batch_size, segmentation] + """ + input_shape = [batch_size, sequence, features] + inputs = jax.random.normal(jax.random.PRNGKey(2), input_shape, dtype=jnp.float32) + + # dummy targets same shape as inputs to use for a dummy loss function to check gradient correctness + dummy_targets = jax.random.normal(jax.random.PRNGKey(3), input_shape, dtype=jnp.float32) + + inputs_position = jnp.array([jnp.arange(sequence, dtype=jnp.int32) for _ in range(batch_size)], dtype=jnp.int32) + inputs_segmentation = jnp.ones((batch_size, sequence), dtype=jnp.int32) + return inputs, dummy_targets, inputs_position, inputs_segmentation + + inputs, dummy_targets, inputs_position, inputs_segmentation = get_inputs( + config.global_batch_size_to_train_on, config.max_target_length, config.emb_dim + ) + deterministic = True + # We use a simpler single matmul decoder layer for fast compilation in these tests. 
+ rngs = nnx.Rngs(params=0) + single_pipeline_stage = simple_layer.SimpleDecoderLayerToLinen( + config=config, mesh=mesh, model_mode=model_mode, rngs=rngs + ) + my_pipeline = pipeline.Pipeline(config=config, layers=single_pipeline_stage, mesh=mesh) + init_pipeline_params = my_pipeline.init( + jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode + ) + partition_spec = my_pipeline.get_weight_sharding( + inputs, inputs_position, inputs_segmentation, deterministic, model_mode + ) + + # Create a dummy scalar loss function so we may take the gradient wrt weights + def pipeline_parallelism_dummy_loss_extra( + params, + inputs, + inputs_position, + inputs_segmentation, + deterministic, + model_mode, + dummy_targets, + partition_spec=None, + ): + outputs = my_pipeline.apply( + params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, partition_spec=partition_spec + ) + loss = jnp.linalg.norm(outputs - dummy_targets) + return loss + + pipeline_parallelism_dummy_loss = functools.partial( + pipeline_parallelism_dummy_loss_extra, partition_spec=partition_spec + ) + + def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): + def get_cur_layer_params(params, layer_idx): + def get_cur_layer_params_arr(leaf): + # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage == 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage] -> [layers] + elif config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1] * leaf.shape[2],) + leaf.shape[3:] + leaf = jnp.reshape(leaf, new_shape) # [repeat, stage, layers_per_stage] -> [layers] + elif config.num_pipeline_repeats == 1 and config.num_layers_per_pipeline_stage > 1: + new_shape = (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:] + leaf = jnp.reshape(leaf, new_shape) # [stage, layers_per_stage] -> [layers] + return leaf[layer_idx] + + return jax.tree.map(get_cur_layer_params_arr, params) + + reg_layer_activations = inputs + for layer in range(config.num_decoder_layers): + cur_layer_params = get_cur_layer_params(params, layer) + cur_layer_params["params"] = cur_layer_params["params"]["layers"] + if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + cur_layer_params["params"] = meta.remove_axis( + cur_layer_params["params"], 0, {nn.PARTITION_NAME: "circular_repeats"} + ) + cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"}) + reg_layer_activations, _ = single_pipeline_stage.apply( + cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode + ) + return reg_layer_activations + + def regular_sequential_layers_dummy_loss( + params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets + ): + outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + loss = jnp.linalg.norm(outputs - dummy_targets) + return loss + + assert_same_output_and_grad( + regular_sequential_layers_dummy_loss, + pipeline_parallelism_dummy_loss, + init_pipeline_params, + inputs, + inputs_segmentation, + inputs_position, + deterministic, + model_mode, + dummy_targets, + ) + + @pytest.mark.tpu_only + def 
test_circular_minimum_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="circular_minimum_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=4, + per_device_batch_size=4, + ) + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu_only + def test_circular_extra_microbatches_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="circular_extra_microbatches", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4, + ) + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu_only + def test_circular_deepseek_megablox_same_output_and_grad(self): + # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="circular_moe", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4, + num_experts=4, + num_experts_per_tok=2, + megablox=False, + sparse_matmul=False, + capacity_factor=1, + decoder_block="deepseek", + ) + self.assert_pipeline_same_output_and_grad(config, single_pipeline_stage_class=deepseek.DeepSeekMoELayer) + + @pytest.mark.tpu_only + def test_circular_ag_once(self): + # 2 stages, 8 microbatches, all gather once + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="circular_ag_once", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=2, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4, + pipeline_fsdp_ag_once=True, + ) + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.tpu_only + def test_non_circular_same_output_and_grad(self): + # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + run_name="non_circular", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=4, + num_pipeline_microbatches=4, + per_device_batch_size=4, + ) + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.integration_test + @pytest.mark.tpu_only + def test_full_train_circular(self): + # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches + train_main( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + 
"head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=2", + "num_pipeline_microbatches=8", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", + "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. + ] + ) + + @pytest.mark.tpu_only + def test_delay_activation_forwarding_same_output_and_grad(self): + # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="activation_forwarding", + max_target_length=128, + base_emb_dim=28, + ici_pipeline_parallelism=4, + base_num_decoder_layers=8, + num_pipeline_microbatches=8, + per_device_batch_size=4, + pipeline_delay_activation_forwarding=True, + ) + self.assert_pipeline_same_output_and_grad(config) + + @pytest.mark.integration_test + @pytest.mark.tpu_only + def test_full_train_non_circular(self): + # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches + train_main( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=8", + "num_pipeline_microbatches=8", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", + "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. + ] + ) + + @pytest.mark.integration_test + @pytest.mark.tpu_only + def test_subset_layers(self): + # Run a full train.py call with 4 stages, 16 layers - 8 in pipeline, 8 ran outside of pipeline + train_main( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_test", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=16", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=4", + "num_layers_per_pipeline_stage=1", + "num_pipeline_repeats=2", + "pipeline_parallel_layers=8", + "num_pipeline_microbatches=8", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", + "scan_layers_per_stage=False", # We see better performance only scanning the pipeline iterations. 
+ ] + ) + + @pytest.mark.integration_test + def test_full_train_fp8(self): + # Run a full train.py call with fp8 quantization, which adds extra + # variable collections that need to be handled + train_main( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_fp8_test", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=4", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=4", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", + "quantization=fp8", + "scan_layers_per_stage=False", + "attention=dot_product", + ] + ) + + @pytest.mark.integration_test + def test_full_train_nanoo_fp8(self): + # Run a full train.py call with NANOO fp8 quantization, which adds extra + # variable collections that need to be handled + train_main( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_pipeline_parallelism_nanoo_fp8_test", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=4", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=4", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}", + "quantization=nanoo_fp8", + "scan_layers_per_stage=False", + "attention=dot_product", + ] + ) + + +if __name__ == "__main__": + unittest.main()