diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py
index 5eb91fc5f..2db5690db 100644
--- a/src/MaxText/layers/deepseek.py
+++ b/src/MaxText/layers/deepseek.py
@@ -171,6 +171,9 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -240,6 +243,10 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py
index 5917b522c..fe9738130 100644
--- a/src/MaxText/layers/deepseek_batchsplit.py
+++ b/src/MaxText/layers/deepseek_batchsplit.py
@@ -61,6 +61,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
diff --git a/src/MaxText/layers/gemma.py b/src/MaxText/layers/gemma.py
index dcd237162..8304d3047 100644
--- a/src/MaxText/layers/gemma.py
+++ b/src/MaxText/layers/gemma.py
@@ -132,6 +132,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
diff --git a/src/MaxText/layers/gemma2.py b/src/MaxText/layers/gemma2.py
index 3d0d39efe..116983073 100644
--- a/src/MaxText/layers/gemma2.py
+++ b/src/MaxText/layers/gemma2.py
@@ -226,6 +226,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
diff --git a/src/MaxText/layers/gemma3.py b/src/MaxText/layers/gemma3.py
index 3ecc8991c..1906af5aa 100644
--- a/src/MaxText/layers/gemma3.py
+++ b/src/MaxText/layers/gemma3.py
@@ -193,6 +193,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
diff --git a/src/MaxText/layers/gpt3.py b/src/MaxText/layers/gpt3.py
index 1bec2bb26..8876a06dc 100644
--- a/src/MaxText/layers/gpt3.py
+++ b/src/MaxText/layers/gpt3.py
@@ -430,6 +430,10 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
+
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx = self.pre_self_attention_norm(inputs)
diff --git a/src/MaxText/layers/gpt_oss.py b/src/MaxText/layers/gpt_oss.py
index 1301a46b9..111635910 100644
--- a/src/MaxText/layers/gpt_oss.py
+++ b/src/MaxText/layers/gpt_oss.py
@@ -149,6 +149,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
diff --git a/src/MaxText/layers/llama2.py b/src/MaxText/layers/llama2.py
index 7148b1a5b..3b1310655 100644
--- a/src/MaxText/layers/llama2.py
+++ b/src/MaxText/layers/llama2.py
@@ -152,6 +152,9 @@ def __call__(
   ):
     cfg = self.config
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.activation_axis_names))
diff --git a/src/MaxText/layers/llama4.py b/src/MaxText/layers/llama4.py
index 9db76ae5b..04828881f 100644
--- a/src/MaxText/layers/llama4.py
+++ b/src/MaxText/layers/llama4.py
@@ -454,6 +454,9 @@ def __call__(
     cfg = self.config
     assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
diff --git a/src/MaxText/layers/mistral.py b/src/MaxText/layers/mistral.py
index 643fecaae..2ba140644 100644
--- a/src/MaxText/layers/mistral.py
+++ b/src/MaxText/layers/mistral.py
@@ -135,9 +135,11 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
-    cfg = self.config
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx = self.pre_self_attention_layer_norm(inputs)
diff --git a/src/MaxText/layers/mixtral.py b/src/MaxText/layers/mixtral.py
index 8d23e72d3..d90c3015a 100644
--- a/src/MaxText/layers/mixtral.py
+++ b/src/MaxText/layers/mixtral.py
@@ -140,7 +140,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
-
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py
index 5898d422f..62eb67ea4 100644
--- a/src/MaxText/layers/qwen3.py
+++ b/src/MaxText/layers/qwen3.py
@@ -708,7 +708,9 @@ def __call__(
     # Loop over the number of sub-layers that make up one repeating pattern.
     for i in range(cfg.inhomogeneous_layer_cycle_interval):
       layer = getattr(self, f"layer_{i}")
-      x = layer(
+      # The second return value is kv_cache, which we ignore here because
+      # it is not passed as a carry in scannable layers.
+      x, _ = layer(
           x,
           decoder_segment_ids,
           decoder_positions,
@@ -802,6 +804,9 @@ def __call__(
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     residual = inputs
 
     # First LayerNorm, applied before the attention block.
@@ -1001,6 +1006,9 @@ def __call__(
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
         inputs,
         decoder_segment_ids,
@@ -1065,6 +1073,9 @@ def __call__(
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
         inputs,
         decoder_segment_ids,
diff --git a/src/MaxText/layers/simple_layer.py b/src/MaxText/layers/simple_layer.py
index fdfbc07d1..d695ccc60 100644
--- a/src/MaxText/layers/simple_layer.py
+++ b/src/MaxText/layers/simple_layer.py
@@ -61,6 +61,9 @@ def __init__(
   def __call__(
       self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, previous_chunk=None, page_state=None
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     if self.config.scan_layers:
       return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding), None
     return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding)
@@ -124,6 +127,9 @@ def __call__(
       page_state=None,
       slot=0,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     intermediate = jnp.dot(inputs, self.ff_1.astype(inputs.dtype), out_sharding=self.mlp_sharding)
     output = jnp.dot(intermediate, self.ff_2.astype(inputs.dtype), out_sharding=self.activation_sharding)
     if self.config.scan_layers:
diff --git a/tests/train_smoke_test.py b/tests/train_smoke_test.py
index 1ad53afa7..b839232e6 100644
--- a/tests/train_smoke_test.py
+++ b/tests/train_smoke_test.py
@@ -53,6 +53,35 @@ def test_tiny_config(self):
         ]
     )
 
+  def test_tiny_config_no_scan(self):
+    test_tmpdir = os.environ.get("TEST_TMPDIR")  # pylint: disable=unused-variable
+    train_main(
+        [
+            None,
+            os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+            # pylint: disable=f-string-without-interpolation
+            f"base_output_directory=gs://runner-maxtext-logs",
+            "run_name=runner_test",
+            r"dataset_path=gs://maxtext-dataset",
+            "base_emb_dim=8",
+            "base_num_query_heads=4",
+            "base_num_kv_heads=4",
+            "base_mlp_dim=32",
+            "base_num_decoder_layers=8",
+            "head_dim=128",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "dataset_type=synthetic",
+            "steps=10",
+            "enable_checkpointing=False",
+            rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
+            "enable_goodput_recording=False",
+            "enable_checkpoint_cloud_logger=False",
+            "monitor_goodput=False",
+            "scan_layers=False",
+        ]
+    )
+
   def test_tiny_config_explicit_shardmode(self):
     test_tmpdir = os.environ.get("TEST_TMPDIR")  # pylint: disable=unused-variable
     train_main(
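The guard added throughout this patch handles the case where an unscanned decoder layer returns a (hidden_states, kv_cache) tuple and the caller feeds that result directly into the next layer. A minimal sketch of the pattern, using a hypothetical ToyDecoderLayer (not a class from this diff) and assuming only that JAX is installed:

import jax.numpy as jnp


class ToyDecoderLayer:
  """Hypothetical stand-in for the decoder layers touched in this patch."""

  def __call__(self, inputs):
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)).
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    hidden_states = inputs + 1.0  # placeholder for the attention + MLP computation
    kv_cache = None               # placeholder for the layer's KV cache
    # With scan_layers=False each layer returns (hidden_states, kv_cache),
    # so a plain sequential loop hands the next layer a tuple.
    return hidden_states, kv_cache


x = jnp.zeros((2, 16, 8))  # [batch, length, emb_dim]
for layer in [ToyDecoderLayer() for _ in range(3)]:
  x = layer(x)  # after the first iteration, x is a tuple and gets unpacked by the guard
hidden_states, _ = x if isinstance(x, tuple) else (x, None)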