diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 81879a1a7..07c46be53 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -34,8 +34,8 @@ from MaxText.layers.embeddings import Embed, embed_as_linen from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock -from MaxText.sharding import all_gather_over_fsdp +from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen +from MaxText.maxtext_utils import all_gather_over_fsdp # ------------------------------------------------------------------------------ # The network: Transformer Definitions # ------------------------------------------------------------------------------ @@ -94,8 +94,12 @@ def setup(self): # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list. mtp_layer = layer_types[-1] - self.mtp_block = MultiTokenPredictionBlock( - config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder + self.mtp_block = multi_token_prediction_block_as_linen( + config=self.config, + mesh=self.mesh, + transformer_layer_module=mtp_layer, + decoder=self.decoder, + rngs=nnx.Rngs(self.make_rng("mtp_block")), ) def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): @@ -285,7 +289,15 @@ class Transformer(nnx.Module): # Make new attributes required, so that all Transformer dependencies (train, decode, # compile, etc) will error instead of silently use defaults. # pylint: disable=attribute-defined-outside-init - def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs): + def __init__( + self, + config: Config, + mesh: Mesh, + quant: Quant, + *, + model_mode: str = MODEL_MODE_TRAIN, + rngs: nnx.Rngs, + ): """Initialize shared_embedding & decoder layers.""" self.config = config self.mesh = mesh @@ -347,8 +359,13 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list.
mtp_layer = layer_types[-1] - mtp_block_linen = MultiTokenPredictionBlock( - config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder + mtp_block_linen = multi_token_prediction_block_as_linen( + config=self.config, + mesh=self.mesh, + transformer_layer_module=mtp_layer, + decoder=self.decoder, + rngs=rngs, + name="mtp_block", ) self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs) @@ -593,7 +610,10 @@ def __call__( page_state=page_state, ) all_model_weights = all_gather_over_fsdp( - self.model.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules + self.model.variables, + partition_spec, + mesh=self.mesh, + logical_axis_rules=self.config.logical_axis_rules, ) return self.model.apply( diff --git a/src/MaxText/layers/multi_token_prediction.py b/src/MaxText/layers/multi_token_prediction.py index 24c084049..a3201de36 100644 --- a/src/MaxText/layers/multi_token_prediction.py +++ b/src/MaxText/layers/multi_token_prediction.py @@ -21,140 +21,141 @@ from jax.sharding import Mesh from flax import linen as nn +from flax import nnx from MaxText.common_types import Config, MODEL_MODE_TRAIN -from MaxText.layers.linears import dense_general -from MaxText.layers.normalizations import rms_norm -from MaxText.layers.decoders import Decoder, DecoderLayer +from MaxText.layers.linears import DenseGeneral +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.decoders import DecoderLayer +from MaxText.layers import nnx_wrappers from MaxText import max_utils from MaxText import maxtext_utils from MaxText.globals import EPS +from MaxText.layers.initializers import variable_to_logically_partitioned + + +# Custom Variable types for MTP intermediate outputs +# These will be automatically converted to Linen mutable collections by ToLinen wrapper +# The class names become collection names directly (no case conversion) +class mtp_losses(nnx.Variable): # pylint: disable=invalid-name + """Variable type for storing MTP loss components -> 'mtp_losses' collection.""" + + +class mtp_acceptance(nnx.Variable): # pylint: disable=invalid-name + """Variable type for storing MTP acceptance predictions -> 'mtp_acceptance' collection.""" def roll_and_mask(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray: - """ - Performs a leftward roll on the sequence axis (axis=1) and masks the - newly created invalid positions at the end of the sequence. - Assumes input `x` has a batch dimension at axis 0 and sequence at axis 1. + """Performs a leftward roll on sequence axis and masks invalid positions. Args: - x: The input array of shape [batch, seq_len, ...]. - shift: The number of positions to shift left. + x: Input array of shape [batch, seq_len, ...]. + shift: Number of positions to shift left. Returns: - The rolled array of the same shape as x. + Rolled array with masked positions set to zero. """ - # If shift is 0, it's a no-op. Return the original array. if shift == 0: return x - - # to set the last `abs(shift)` elements of the sequence to zero. return jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0) -class MultiTokenPredictionLayer(nn.Module): - """ - Implements Multi-Token Prediction (MTP) step: - 1. Normalization of previous hidden state and target token embedding. - 2. Concatenation and Projection of normalized features. - 3. Processing through a Transformer Decoder Layer. 
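As a quick sanity check of the rewritten roll_and_mask, here is a minimal standalone sketch (plain jnp, no MaxText imports) of the shift-and-zero semantics on a toy batch:

import jax.numpy as jnp

def roll_and_mask(x, shift=-1):
    # Identical logic to the function above, reproduced for a standalone demo.
    if shift == 0:
        return x
    return jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0)

x = jnp.array([[1, 2, 3, 4]])
roll_and_mask(x)      # [[2, 3, 4, 0]] -- window advances one token, tail is zeroed
roll_and_mask(x, -3)  # [[4, 0, 0, 0]] -- last |shift| positions are masked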
- - Equation Representation (Conceptual): - norm_h = RMSNorm(h_prev) - norm_e = RMSNorm(e_target) - h_proj = W_p(concat(norm_h, norm_e)) - h_next = TransformerLayer(h_proj, pos_ids, segment_ids, ...) - - It takes the previous hidden state and target embedding as input and outputs the - processed hidden state from its internal transformer block. - """ +class MultiTokenPredictionLayer(nnx.Module): + """Multi-Token Prediction layer: normalize, concatenate, project, and transform. - config: Config - mesh: Mesh - layer_number: int - transformer_layer_module: Type[DecoderLayer] = DecoderLayer + Implements: h_next = TransformerLayer(W_p(concat(RMSNorm(h_prev), RMSNorm(e_target)))) + """ - @nn.compact - def __call__( + def __init__( self, - prev_hidden_state: jnp.ndarray, - target_token_embedding: jnp.ndarray, - position_ids: jnp.ndarray, - decoder_segment_ids: None | jnp.ndarray, - deterministic: bool, - model_mode: str = MODEL_MODE_TRAIN, - ) -> jnp.ndarray: - """ - Applies the MTP combination, projection, and internal transformer processing. - - Args: - prev_hidden_state: Hidden state from the previous step/layer. - Shape: [batch, seq_len, hidden_size] - target_token_embedding: Embedding of the target token. In the context of MTP, - this often refers to a token at a position relative - to the current step, where the offset is determined - by the layer number `k` (i.e., token t+k). - Shape: [batch, seq_len, embed_dim] - position_ids: Original position IDs for the sequence. - Shape: [batch, seq_len] - decoder_segment_ids: Original segment IDs for the sequence (for attention mask). - Shape: [batch, seq_len] - deterministic: If true, disable dropout. - model_mode: The current operational mode (train, eval, decode). - - Returns: - next_hidden_state: The hidden state produced by this MTP step's internal transformer. - Shape: [batch, seq_len, hidden_size] - """ + config: Config, + mesh: Mesh, + layer_number: int, + transformer_layer_module: Type[DecoderLayer], + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.layer_number = layer_number + self.transformer_layer_module = transformer_layer_module + self.rngs = rngs + k = layer_number cfg = self.config - mesh = self.mesh - k = self.layer_number - # --- 1. Normalize Hidden State and Embedding --- - embedding_norm_layer = rms_norm( - num_features=target_token_embedding.shape[-1], + self.embedding_norm = RMSNorm( + num_features=cfg.base_emb_dim, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name=f"mtp_{k}_embedding_norm", - epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), + rngs=rngs, ) - embedding_norm = embedding_norm_layer(target_token_embedding) - - hidden_state_norm_layer = rms_norm( - num_features=prev_hidden_state.shape[-1], + self.hidden_state_norm = RMSNorm( + num_features=cfg.base_emb_dim, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name=f"mtp_{k}_hidden_state_norm", - epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), + rngs=rngs, ) - - hidden_state_norm = hidden_state_norm_layer(prev_hidden_state) - - # --- 2. Concatenate Normalized Representations --- - # Shape: [B, S, 2*H] - concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1) - - # --- 3. 
Project Concatenated Features --- - # Projects from 2*H back down to H - projection_layer = dense_general( - inputs_shape=concatenated_features.shape, + self.projection_layer = DenseGeneral( + in_features_shape=2 * cfg.base_emb_dim, out_features_shape=cfg.base_emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, use_bias=False, kernel_axes=("concat_embed", "embed"), - name=f"mtp_{k}_projection", + rngs=rngs, + ) + # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically. + mtp_transformer_layer = transformer_layer_module( + config=cfg, + mesh=mesh, + model_mode=MODEL_MODE_TRAIN, + name=f"mtp_{k}_transformer_layer", ) - # Shape: [B, S, H] - projected_features = projection_layer(concatenated_features) + self.transformer_layer = nnx_wrappers.ToNNX(mtp_transformer_layer, rngs=rngs) + + # ToNNX requires explicit initialization with sample inputs for proper parameter setup. + self.transformer_layer.lazy_init( + inputs=jnp.zeros((1, 1, cfg.base_emb_dim), dtype=cfg.dtype), + decoder_segment_ids=None, + decoder_positions=jnp.zeros((1, 1), dtype=jnp.int32), + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + def __call__( + self, + prev_hidden_state: jnp.ndarray, + target_token_embedding: jnp.ndarray, + *, + position_ids: jnp.ndarray, + decoder_segment_ids: None | jnp.ndarray, + deterministic: bool, + model_mode: str = MODEL_MODE_TRAIN, + ) -> jnp.ndarray: + """Applies MTP combination, projection, and transformer processing. - # --- 4. Pass through MTP Transformer Block --- - output = self.transformer_layer_module( - config=cfg, mesh=mesh, model_mode=model_mode, name=f"mtp_{k}_transformer_layer" - )( + Args: + prev_hidden_state: Shape [batch, seq_len, hidden_size]. + target_token_embedding: Embedding for token t+k. Shape [batch, seq_len, embed_dim]. + position_ids: Shape [batch, seq_len]. + decoder_segment_ids: Shape [batch, seq_len] or None. + deterministic: Whether to disable dropout. + model_mode: Operational mode (train, eval, decode). + + Returns: + Processed hidden state. Shape [batch, seq_len, hidden_size]. + """ + embedding_norm = self.embedding_norm(target_token_embedding) + hidden_state_norm = self.hidden_state_norm(prev_hidden_state) + concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1) + projected_features = self.projection_layer(concatenated_features) + + output = self.transformer_layer( inputs=projected_features, decoder_segment_ids=decoder_segment_ids, decoder_positions=position_ids, @@ -162,27 +163,44 @@ def __call__( model_mode=model_mode, ) - if isinstance(output, tuple): - # Handles the scan=True case, where the output is a tuple. - next_hidden_state = output[0] - else: - # Handles the scan=False case, where the output is a single tensor. 
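For intuition, the combine-and-project step of MultiTokenPredictionLayer.__call__ reduces to the following shape-level sketch; the hand-rolled rms_norm and random W_p are stand-ins for the real RMSNorm/DenseGeneral modules, and the dimensions are made up:

import jax
import jax.numpy as jnp

B, S, H = 2, 4, 8  # hypothetical batch, sequence, and embed sizes
key = jax.random.PRNGKey(0)
h_prev = jax.random.normal(key, (B, S, H))    # hidden state from step k-1
e_target = jax.random.normal(key, (B, S, H))  # embedding of token t+k

def rms_norm(x, eps=1e-6):
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

W_p = jax.random.normal(key, (2 * H, H)) / jnp.sqrt(2.0 * H)  # stands in for the MTP projection
h_proj = jnp.concatenate([rms_norm(e_target), rms_norm(h_prev)], axis=-1) @ W_p
assert h_proj.shape == (B, S, H)  # ready for the MTP transformer layer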
- next_hidden_state = output + return output[0] if isinstance(output, tuple) else output - # Shape: [B, S, H] - # --- Return Processed Hidden State --- - return next_hidden_state - -class MultiTokenPredictionBlock(nn.Module): +class MultiTokenPredictionBlock(nnx.Module): """Orchestrates the MTP process by running a sequence of MTP layers.""" - config: Config - mesh: Mesh - transformer_layer_module: Type[DecoderLayer] - decoder: Decoder + def __init__( + self, + config: Config, + mesh: Mesh, + transformer_layer_module: Type[DecoderLayer], + decoder: nnx.Module, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.transformer_layer_module = transformer_layer_module + self.decoder = decoder + self.rngs = rngs if rngs is not None else nnx.Rngs(0) + + # NNX Variables are exposed as Linen mutable collections by ToLinen wrapper. + self.losses = mtp_losses(jnp.zeros((config.mtp_num_layers,), dtype=jnp.float32)) + self.weights = mtp_losses(jnp.zeros((config.mtp_num_layers,), dtype=jnp.float32)) + # Float32 used to avoid gradient errors; converted to int32 in acceptance rate calculation. + self.mtp_preds = mtp_acceptance(jnp.zeros((1,), dtype=jnp.float32)) + self.mtp_mask = mtp_acceptance(jnp.zeros((1,), dtype=jnp.float32)) + + # 1-indexed to match paper convention. + for k in range(1, config.mtp_num_layers + 1): + layer = MultiTokenPredictionLayer( + config=config, + mesh=mesh, + layer_number=k, + transformer_layer_module=transformer_layer_module, + rngs=self.rngs.fork(), + ) + setattr(self, f"mtp_layer_{k}", layer) - @nn.compact def __call__( self, shared_embedding, @@ -190,126 +208,162 @@ def __call__( input_ids, target_ids, target_mask, + *, position_ids, decoder_segment_ids, model_mode, deterministic, - ): + ) -> dict: cfg = self.config - # The initial hidden state for the MTP chain is the raw output from the main model. mtp_hidden_state = main_hidden_state - # These variables are updated sequentially in each loop iteration, - # moving the prediction window one token to the right each time. + # Rolling variables move prediction window one token to the right per iteration. rolled_input_ids = input_ids rolled_target_ids = target_ids rolled_target_mask = target_mask rolled_position_id = position_ids - # Range chosen to align with the naming convention of the paper + mtp_losses_list = [] + mtp_weights_list = [] + mtp_preds_list = [] + mtp_masks_list = [] + for k in range(1, cfg.mtp_num_layers + 1): - # Sequentially roll all tensors to prepare data for predicting the k-th future token.
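The mtp_losses/mtp_acceptance Variable subclasses used above follow the standard NNX pattern: declare in __init__, "sow" by assigning .value in __call__, and read back by type filter. A minimal toy sketch (not MaxText code):

from flax import nnx
import jax.numpy as jnp

class mtp_losses(nnx.Variable):  # same lowercase naming convention as above
    pass

class Toy(nnx.Module):
    def __init__(self):
        # Pre-sized with zeros, mirroring the checkpoint-friendly initialization above.
        self.losses = mtp_losses(jnp.zeros((2,)))

    def __call__(self):
        self.losses.value = jnp.array([1.0, 2.0])  # "sow" by assignment

m = Toy()
m()
nnx.state(m, mtp_losses)  # State({'losses': mtp_losses(value=Array([1., 2.]))})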
rolled_input_ids = roll_and_mask(rolled_input_ids) rolled_target_ids = roll_and_mask(rolled_target_ids) rolled_target_mask = roll_and_mask(rolled_target_mask) rolled_position_id = roll_and_mask(rolled_position_id) - # Embed the k-th future input tokens using the shared embedding module target_token_embedding = self.decoder._apply_embedding( - shared_embedding, rolled_input_ids, rolled_position_id, deterministic, self.decoder.model_mode - ) - - # Instantiate and apply the MTP layer for this step - mtp_layer = MultiTokenPredictionLayer( - config=cfg, - mesh=self.mesh, - layer_number=k, - name=f"mtp_layer_{k}", - transformer_layer_module=self.transformer_layer_module, + shared_embedding, + rolled_input_ids, + rolled_position_id, + deterministic, + model_mode=self.decoder.model_mode, ) - next_mtp_hidden_state = mtp_layer( - mtp_hidden_state, - target_token_embedding, - position_ids, - decoder_segment_ids, - deterministic, - self.decoder.model_mode, + mtp_layer = getattr(self, f"mtp_layer_{k}") + mtp_hidden_state = mtp_layer( + prev_hidden_state=mtp_hidden_state, + target_token_embedding=target_token_embedding, + position_ids=position_ids, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=self.decoder.model_mode, ) - # Project to logits using the shared embedding transpose - mtp_logits = self.decoder.apply_output_head(shared_embedding, next_mtp_hidden_state, deterministic, model_mode) + mtp_logits = self.decoder.apply_output_head(shared_embedding, mtp_hidden_state, deterministic, model_mode) - # Calculate cross-entropy loss for this specific layer's prediction mtp_xent, _ = max_utils.cross_entropy_with_logits( mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0 ) mtp_xent_masked = mtp_xent * rolled_target_mask - # This logic doesn't run during model initialization to avoid unwated population of the mutable collections. - if not self.is_initializing(): - # For evaluation, save the top prediction and a valid token mask. - # This is only active for the target layer during an eval run. - if cfg.mtp_eval_target_module == k and self.is_mutable_collection("mtp_acceptance"): - mtp_top_1_pred = jnp.argmax(mtp_logits, axis=-1) - self.sow("mtp_acceptance", "mtp_preds", mtp_top_1_pred) - self.sow("mtp_acceptance", "mtp_mask", rolled_target_mask) - - # For training, save the loss components for this MTP head. - # This is only active during a training run. - if self.is_mutable_collection("mtp_losses"): - self.sow("mtp_losses", "losses", jnp.sum(mtp_xent_masked)) - self.sow("mtp_losses", "weights", jnp.sum(rolled_target_mask)) + if model_mode == MODEL_MODE_TRAIN: + mtp_losses_list.append(jnp.sum(mtp_xent_masked)) + mtp_weights_list.append(jnp.sum(rolled_target_mask).astype(jnp.float32)) - # The output of this layer is the input for the next, maintaining the causal chain. - mtp_hidden_state = next_mtp_hidden_state + if cfg.mtp_eval_target_module == k: + # Float32 to avoid gradient errors; converted back to int32 in acceptance calculation. 
+ mtp_preds_list.append(jnp.argmax(mtp_logits, axis=-1).astype(jnp.float32)) + mtp_masks_list.append(rolled_target_mask) + if mtp_losses_list: + self.losses.value = jnp.stack(mtp_losses_list) + self.weights.value = jnp.stack(mtp_weights_list) + if mtp_preds_list: + self.mtp_preds.value = jnp.stack(mtp_preds_list) + self.mtp_mask.value = jnp.stack(mtp_masks_list) -def calculate_mtp_loss(intermediate_outputs, config): - """Calculates the Multi Token Prediction loss from intermediate outputs.""" - losses_path = ("mtp_losses", "mtp_block", "losses") - weights_path = ("mtp_losses", "mtp_block", "weights") + return {} - mtp_losses = maxtext_utils.get_nested_value(intermediate_outputs, losses_path, default=()) - mtp_weights = maxtext_utils.get_nested_value(intermediate_outputs, weights_path, default=()) - if not mtp_losses: # MTP heads did not run +def calculate_mtp_loss(intermediate_outputs, config): + """Calculates Multi-Token Prediction loss from intermediate outputs.""" + mtp_losses_data = maxtext_utils.get_nested_value( + intermediate_outputs, ("mtp_losses", "mtp_block", "losses"), default=None + ) + mtp_weights_data = maxtext_utils.get_nested_value( + intermediate_outputs, ("mtp_losses", "mtp_block", "weights"), default=None + ) + + if mtp_losses_data is None: return 0.0 - sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses)) - sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights)) + # Handle both tuple (Linen sow) and array (NNX Variable) formats. + if isinstance(mtp_losses_data, (tuple, list)): + if not mtp_losses_data: + return 0.0 + mtp_losses_array = jnp.array(mtp_losses_data) + mtp_weights_array = jnp.array(mtp_weights_data) + else: + if mtp_losses_data.size == 0: + return 0.0 + mtp_losses_array = mtp_losses_data + mtp_weights_array = mtp_weights_data - avg_mtp_loss = sum_of_all_mtp_losses / (sum_of_all_mtp_weights + EPS) - scaled_mtp_loss = avg_mtp_loss * config.mtp_loss_scaling_factor - return scaled_mtp_loss + avg_mtp_loss = jnp.sum(mtp_losses_array) / (jnp.sum(mtp_weights_array) + EPS) + return avg_mtp_loss * config.mtp_loss_scaling_factor def calculate_mtp_acceptance_rate(intermediate_outputs, config): - """Calculates the MTP acceptance rate from intermediate outputs.""" - + """Calculates MTP acceptance rate from intermediate outputs.""" sown_data = maxtext_utils.get_nested_value(intermediate_outputs, ("mtp_acceptance", "mtp_block"), {}) - mtp_preds = maxtext_utils.get_nested_value(sown_data, ("mtp_preds",), [None])[0] - valid_mask = maxtext_utils.get_nested_value(sown_data, ("mtp_mask",), [None])[0] - # These values are only "sown" (saved) during an evaluation run and only for the specific - # MTP layer specified by `config.mtp_eval_target_module`. This check handles cases - # where the required data is absent (e.g., during a training step) and prevents errors. + # Handle both tuple (Linen sow) and array (NNX Variable) formats. + mtp_preds_raw = maxtext_utils.get_nested_value(sown_data, ("mtp_preds",), None) + valid_mask_raw = maxtext_utils.get_nested_value(sown_data, ("mtp_mask",), None) + + mtp_preds = mtp_preds_raw[0] if isinstance(mtp_preds_raw, (tuple, list)) and mtp_preds_raw else mtp_preds_raw + valid_mask = valid_mask_raw[0] if isinstance(valid_mask_raw, (tuple, list)) and valid_mask_raw else valid_mask_raw + + # Only populated during eval for the target MTP module. if mtp_preds is None or valid_mask is None: return 0.0 - # Get the main model's greedy predictions from the logits. 
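Numerically, calculate_mtp_loss above is a weighted average over heads followed by scaling; a worked toy example (EPS and the scaling factor are stand-in values):

import jax.numpy as jnp

EPS = 1e-8   # stand-in for MaxText.globals.EPS
scale = 0.1  # hypothetical config.mtp_loss_scaling_factor
losses = jnp.array([4.0, 6.0])   # per-head summed masked cross-entropy
weights = jnp.array([8.0, 8.0])  # per-head counts of valid target tokens
mtp_loss = jnp.sum(losses) / (jnp.sum(weights) + EPS) * scale
# (4 + 6) / 16 * 0.1 = 0.0625 -- one averaged, scaled loss across all MTP heads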
+ mtp_preds = mtp_preds.astype(jnp.int32) main_model_preds = jnp.argmax(intermediate_outputs["logits"], axis=-1) - # Roll the main model's predictions to align them in time with the MTP head's target. + # Align main model predictions with MTP head target by rolling k steps. rolled_main_preds = main_model_preds for _ in range(config.mtp_eval_target_module): rolled_main_preds = roll_and_mask(rolled_main_preds) - # Compare the aligned predictions. The `valid_mask` ensures that the comparison - # only happens on valid tokens, ignoring the placeholder values introduced at the - # end of the sequence by the `roll_and_mask` operation. correct_predictions = jnp.sum((mtp_preds == rolled_main_preds) * valid_mask) total_valid_tokens = jnp.sum(valid_mask) - # Return acceptance rate as a percentage return (correct_predictions / (total_valid_tokens + EPS)) * 100 + + +def multi_token_prediction_block_as_linen( + *, + config: Config, + mesh: Mesh, + transformer_layer_module: Type[DecoderLayer], + decoder: nnx.Module, + rngs: nnx.Rngs, + name: str | None = None, +) -> nn.Module: + """Initializes MultiTokenPredictionBlock as a Linen module. + + Args: + config: Configuration object containing model hyperparameters. + mesh: JAX Mesh for model parallelism. + transformer_layer_module: The Transformer Decoder Layer class to use. + decoder: The decoder module that provides embedding and output head. + rngs: Random number generators for initialization. + name: Optional name for the module. + + Returns: + An instance of MultiTokenPredictionBlock wrapped as a Linen module. + """ + return nnx.bridge.to_linen( + MultiTokenPredictionBlock, + config=config, + mesh=mesh, + transformer_layer_module=transformer_layer_module, + decoder=decoder, + rngs=rngs, + metadata_fn=variable_to_logically_partitioned, + name=name, + ) diff --git a/tests/multi_token_prediction_test.py b/tests/multi_token_prediction_test.py index cdba6cd13..e02763415 100644 --- a/tests/multi_token_prediction_test.py +++ b/tests/multi_token_prediction_test.py @@ -25,11 +25,10 @@ from MaxText import max_logging, pyconfig from MaxText import maxtext_utils from MaxText.globals import MAXTEXT_PKG_DIR -from MaxText.layers.decoders import Decoder, DecoderLayer +from MaxText.layers.decoders import DecoderLayer from MaxText.layers import multi_token_prediction # The class under test from MaxText.layers import embeddings from MaxText.common_types import MODEL_MODE_TRAIN -from MaxText.layers import nnx_wrappers TEST_LAYER_NUM = 1 @@ -43,8 +42,10 @@ def setUp(self): [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], run_name="multi_token_prediction_layer_test", skip_jax_distributed_system=True, + per_device_batch_size=8, ) self.rng = jax.random.PRNGKey(42) # Base RNG for setup + self.rngs = nnx.Rngs(params=self.rng, dropout=self.rng) devices_array = maxtext_utils.create_device_mesh(self.cfg) self.mesh = Mesh(devices_array, self.cfg.mesh_axes) @@ -54,6 +55,7 @@ def setUp(self): mesh=self.mesh, layer_number=TEST_LAYER_NUM, transformer_layer_module=DecoderLayer, + rngs=self.rngs, ) # Dimensions directly from the config object @@ -64,36 +66,25 @@ def setUp(self): # Prepare Dummy Input Data prev_hidden_state_shape = (self.batch_size, self.seq_len, self.embed_dim) target_embedding_shape = (self.batch_size, self.seq_len, self.embed_dim) - data_rng1, data_rng2, init_rng = jax.random.split(self.rng, 3) + data_rng1, data_rng2, _ = jax.random.split(self.rng, 3) self.prev_hidden_state = jax.random.normal(data_rng1, prev_hidden_state_shape, dtype=self.cfg.dtype) 
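The alignment logic in calculate_mtp_acceptance_rate can be checked with a toy example: rolling the main model's greedy predictions k times lines them up with the k-th head's outputs (k and all arrays below are made up):

import jax.numpy as jnp

def roll_and_mask(x, shift=-1):
    return x if shift == 0 else jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0)

k = 2                                   # hypothetical mtp_eval_target_module
main_preds = jnp.array([[5, 6, 7, 8]])  # argmax of the main model's logits
mtp_preds = jnp.array([[7, 8, 0, 0]])   # k-th head's top-1 predictions
mask = jnp.array([[1, 1, 0, 0]])        # valid positions after rolling
rolled = main_preds
for _ in range(k):
    rolled = roll_and_mask(rolled)      # [[7, 8, 0, 0]]
acceptance = jnp.sum((mtp_preds == rolled) * mask) / (jnp.sum(mask) + 1e-8) * 100
# -> 100.0: the head agrees with the main model on both valid tokens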
self.target_token_embedding = jax.random.normal(data_rng2, target_embedding_shape, dtype=self.cfg.dtype) self.position_ids = jnp.arange(self.seq_len, dtype=jnp.int32).reshape(1, -1).repeat(self.batch_size, axis=0) # Simulate a simple case with no padding. self.decoder_segment_ids = jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32) - - # Initialize Layer Parameters - init_rngs = {"params": init_rng, "dropout": init_rng} - self.variables = self.mtp_layer.init( - init_rngs, - self.prev_hidden_state, - self.target_token_embedding, - self.position_ids, - self.decoder_segment_ids, - deterministic=True, - ) max_logging.log("Setup complete.") def test_multi_token_prediction_layer_output(self): """Tests the basic forward pass and output shape of MultiTokenPredictionLayer.""" - output_hidden_state = self.mtp_layer.apply( - self.variables, + output_hidden_state = self.mtp_layer( self.prev_hidden_state, self.target_token_embedding, - self.position_ids, + position_ids=self.position_ids, decoder_segment_ids=self.decoder_segment_ids, deterministic=True, + model_mode=MODEL_MODE_TRAIN, ) # Assertions using unittest methods expected_output_shape = (self.batch_size, self.seq_len, self.embed_dim) @@ -125,32 +116,46 @@ def test_multi_token_prediction_layer_output(self): class MTPBlockTestModel(nnx.Module): """A lightweight wrapper model for testing the MTPBlock.""" - def __init__( - self, - config: Config, - mesh: Mesh, - rngs: nnx.Rngs | None = None, - ): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs): + """Initializes the MTP block and its dependencies for the test.""" self.config = config self.mesh = mesh - """Initializes the MTP block and its dependencies for the test.""" - self.shared_embedding = embeddings.Embed( + self.rngs = rngs if rngs is not None else nnx.Rngs(0) + self._shared_embedding = embeddings.Embed( num_embeddings=self.config.vocab_size, num_features=self.config.base_emb_dim, config=self.config, mesh=self.mesh, - rngs=rngs, + rngs=self.rngs, ) - decoder_for_mtp = Decoder(config=self.config, mesh=self.mesh, name="decoder_for_mtp") - self.multi_token_prediction_block = multi_token_prediction.MultiTokenPredictionBlock( + class MockDecoderForMTP: + """A mock decoder that simulates the behavior needed by MTPBlock.""" + + def __init__(self, config: Config): + self.config = config + self.model_mode = MODEL_MODE_TRAIN + + def _apply_embedding(self, _shared_embedding, input_ids, _position_ids, _deterministic, model_mode): + """Returns a zero tensor with the correct embedding shape.""" + batch_size, seq_len = input_ids.shape + embed_dim = self.config.base_emb_dim + return jnp.zeros((batch_size, seq_len, embed_dim), dtype=self.config.dtype) + + def apply_output_head(self, _shared_embedding, hidden_state, _deterministic, model_mode): + """Returns a zero tensor with the correct logit shape.""" + batch_size, seq_len, _ = hidden_state.shape + return jnp.zeros((batch_size, seq_len, self.config.vocab_size), dtype=self.config.dtype) + + self.decoder = MockDecoderForMTP(config=self.config) + + self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock( config=self.config, mesh=self.mesh, - name="mtp_block", transformer_layer_module=DecoderLayer, - decoder=decoder_for_mtp, + decoder=self.decoder, + rngs=self.rngs, ) - self.mtp_block = nnx_wrappers.ToNNX(self.multi_token_prediction_block, rngs=nnx.Rngs(params=0)) def __call__( self, @@ -158,6 +163,7 @@ def __call__( input_ids, target_ids, target_mask, + *, position_ids, decoder_segment_ids, model_mode, @@ -165,18 +171,21 @@ 
def __call__( mutable=None, ): return self.mtp_block( - self.shared_embedding, + self._shared_embedding, main_hidden_state, input_ids, target_ids, target_mask, - position_ids, - decoder_segment_ids, - model_mode, - deterministic, - mutable=mutable, + position_ids=position_ids, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + deterministic=deterministic, ) + def shared_embedding(self): + """Returns the shared embedding.""" + return self._shared_embedding + class MultiTokenPredictionBlockTest(unittest.TestCase): """Unit tests for the MultiTokenPredictionBlock.""" @@ -188,79 +197,102 @@ def setUp(self): run_name="mtp_block_test", skip_jax_distributed_system=True, mtp_num_layers=2, + base_emb_dim=16, ) self.nnx_rngs = nnx.Rngs(params=0) self.rng = jax.random.PRNGKey(43) + self.rngs = nnx.Rngs(params=self.rng, dropout=self.rng) devices_array = maxtext_utils.create_device_mesh(self.cfg) self.mesh = Mesh(devices_array, self.cfg.mesh_axes) data_rng, self.init_rng = jax.random.split(self.rng) - self.batch_size, self.seq_len, self.embed_dim = 2, 8, 16 + self.batch_size, self.seq_len, self.embed_dim = 2, 8, self.cfg.base_emb_dim key1, key2, key3 = jax.random.split(data_rng, 3) self.main_hidden_state = jax.random.normal(key1, (self.batch_size, self.seq_len, self.embed_dim)) self.input_ids = jax.random.randint(key2, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size) self.target_ids = jax.random.randint(key3, (self.batch_size, self.seq_len), 0, self.cfg.vocab_size) self.target_mask = jnp.ones_like(self.target_ids) - self.position_ids = jnp.arange(self.seq_len, dtype=jnp.int32).reshape(1, -1) + self.position_ids = jnp.arange(self.seq_len, dtype=jnp.int32).reshape(1, -1).repeat(self.batch_size, axis=0) self.decoder_segment_ids = jnp.ones((self.batch_size, self.seq_len), dtype=jnp.int32) - self.test_model = MTPBlockTestModel(config=self.cfg, mesh=self.mesh, rngs=self.nnx_rngs) + self.test_model = MTPBlockTestModel( + config=self.cfg, + mesh=self.mesh, + rngs=self.rngs, + ) + + def test_no_sow_during_init(self): + """Verifies losses/weights are initialized with zeros (NNX behavior).""" + # NNX pre-initializes Variables with zeros to avoid checkpointing errors. + # Unlike Linen which sows during forward pass, NNX creates Variables in __init__. + initial_state = nnx.state(self.test_model) + self.assertTrue(hasattr(initial_state.mtp_block, "losses")) + self.assertTrue(hasattr(initial_state.mtp_block, "weights")) + + # Verify they're initialized with zeros of correct shape. 
+ losses_val = initial_state.mtp_block.losses.value + weights_val = initial_state.mtp_block.weights.value + self.assertEqual(losses_val.shape, (self.cfg.mtp_num_layers,)) + self.assertEqual(weights_val.shape, (self.cfg.mtp_num_layers,)) + self.assertTrue(jnp.all(losses_val == 0.0)) + self.assertTrue(jnp.all(weights_val == 0.0)) def test_sow_functionality(self): """Verifies that the block correctly sows losses and weights.""" - self.test_model( - self.main_hidden_state, - self.input_ids, - self.target_ids, - self.target_mask, - self.position_ids, - self.decoder_segment_ids, - deterministic=True, + _ = self.test_model( + main_hidden_state=self.main_hidden_state, + input_ids=self.input_ids, + target_ids=self.target_ids, + target_mask=self.target_mask, + position_ids=self.position_ids, + decoder_segment_ids=self.decoder_segment_ids, model_mode=MODEL_MODE_TRAIN, - mutable=["mtp_losses"], + deterministic=True, ) - self.assertTrue(hasattr(self.test_model.mtp_block, "losses")) - mtp_loss = self.test_model.mtp_block.losses - self.assertTrue(type(mtp_loss).__name__, "mtp_losses") - self.assertEqual(len(mtp_loss), self.cfg.mtp_num_layers) + state = nnx.state(self.test_model) - def test_no_sow_during_init(self): - """Verifies no losses are sown during model initialization.""" - # `self.variables` was created by `.init()`. We inspect it to ensure - # our `if not self.is_initializing()` check worked. - self.assertFalse(hasattr(self.test_model.mtp_block, "losses")) + # Check for the existence of the 'losses' and 'weights' attributes. + self.assertTrue(hasattr(state.mtp_block, "losses")) + self.assertTrue(hasattr(state.mtp_block, "weights")) + + # Access the underlying array via the .value attribute. + losses_val = state.mtp_block.losses.value + weights_val = state.mtp_block.weights.value + + self.assertEqual(len(losses_val), self.cfg.mtp_num_layers) + self.assertEqual(len(weights_val), self.cfg.mtp_num_layers) def test_loss_aggregation_logic(self): """ Tests the full 'sow and reap' cycle, mimicking the logic from train.py to ensure the final loss calculation is correct. """ - # 1. Run the forward pass and capture the sown variables. - self.test_model( - self.main_hidden_state, - self.input_ids, - self.target_ids, - self.target_mask, - self.position_ids, - self.decoder_segment_ids, - deterministic=False, - mutable=["mtp_losses"], + # Run the forward pass and capture the sown variables. + _ = self.test_model( + main_hidden_state=self.main_hidden_state, + input_ids=self.input_ids, + target_ids=self.target_ids, + target_mask=self.target_mask, + position_ids=self.position_ids, + decoder_segment_ids=self.decoder_segment_ids, model_mode=MODEL_MODE_TRAIN, + deterministic=False, ) + state = nnx.state(self.test_model) # This section of the test now *becomes* the logic from train.py # ------------------------------------------------------------- final_loss_for_gradient = 100.0 # A dummy main loss mtp_loss_for_logging = 0.0 - # 2. Get the weight and losses. - mtp_losses = self.test_model.mtp_block.losses.value - mtp_weights = self.test_model.mtp_block.weights.value + # Read the sown Variables back from the captured state. + mtp_losses_var = getattr(state.mtp_block, "losses", None) + mtp_weights_var = getattr(state.mtp_block, "weights", None) - # 3. Perform the aggregation logic exactly as in `loss_fn`. + # Perform the aggregation logic exactly as in `loss_fn`.
+ if mtp_losses_var is not None and mtp_weights_var is not None: + sum_of_all_mtp_losses = jnp.sum(jnp.array(mtp_losses_var.value)) + sum_of_all_mtp_weights = jnp.sum(jnp.array(mtp_weights_var.value)) self.assertGreater(sum_of_all_mtp_weights, 0) @@ -271,7 +303,7 @@ def test_loss_aggregation_logic(self): mtp_loss_for_logging = scaled_mtp_loss # ------------------------------------------------------------- - # 4. Assert that the final values are correct. + # Assert that the final values are correct. # The final loss should have increased from its base value. self.assertGreater(final_loss_for_gradient, 100.0) # The logged MTP loss should be a valid, positive number. @@ -312,8 +344,15 @@ def test_mtp_roll_and_mask_shapes(self): dtype=jnp.int32, ) - self.assertEqual(rolled_by_1.shape, (batch_size, seq_len), "Shape should be preserved after rolling.") - self.assertTrue(jnp.array_equal(rolled_by_1, expected_1), "Array content is incorrect after shift by -1.") + self.assertEqual( + rolled_by_1.shape, + (batch_size, seq_len), + "Shape should be preserved after rolling.", + ) + self.assertTrue( + jnp.array_equal(rolled_by_1, expected_1), + "Array content is incorrect after shift by -1.", + ) # --- Test Case 2: Larger left shift by 3 --- # This simulates a later step in a hypothetical MTP loop. @@ -329,8 +368,15 @@ def test_mtp_roll_and_mask_shapes(self): ], dtype=jnp.int32, ) - self.assertEqual(rolled_by_3.shape, (batch_size, seq_len), "Shape should be preserved after rolling.") - self.assertTrue(jnp.array_equal(rolled_by_3, expected_3), "Array content is incorrect after shift by -3.") + self.assertEqual( + rolled_by_3.shape, + (batch_size, seq_len), + "Shape should be preserved after rolling.", + ) + self.assertTrue( + jnp.array_equal(rolled_by_3, expected_3), + "Array content is incorrect after shift by -3.", + ) # --- Test Case 3: Shift of 0 (edge case) --- # This should result in no change to the tensor.
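For completeness, the NNX-side "reap" that test_loss_aggregation_logic mimics would look roughly like this inside a train step; the function below is a sketch under the assumption that `model` is any nnx.Module owning an `mtp_block` like the one above:

from flax import nnx
import jax.numpy as jnp

def aggregate_mtp_loss(model, main_loss, scaling_factor, eps=1e-8):
    # Read the per-head sums sown during the forward pass.
    state = nnx.state(model)
    losses = state.mtp_block.losses.value    # shape [mtp_num_layers]
    weights = state.mtp_block.weights.value  # shape [mtp_num_layers]
    mtp_loss = jnp.sum(losses) / (jnp.sum(weights) + eps)
    return main_loss + mtp_loss * scaling_factor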