diff --git a/src/MaxText/layers/deepseek.py b/src/MaxText/layers/deepseek.py
index 5eb91fc5f..2db5690db 100644
--- a/src/MaxText/layers/deepseek.py
+++ b/src/MaxText/layers/deepseek.py
@@ -171,6 +171,9 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
@@ -240,6 +243,10 @@ def __call__(
       logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
+
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/deepseek_batchsplit.py b/src/MaxText/layers/deepseek_batchsplit.py
index 5917b522c..fe9738130 100644
--- a/src/MaxText/layers/deepseek_batchsplit.py
+++ b/src/MaxText/layers/deepseek_batchsplit.py
@@ -61,6 +61,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     x = self.with_logical_constraint(inputs)
     x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/gemma.py b/src/MaxText/layers/gemma.py
index dcd237162..8304d3047 100644
--- a/src/MaxText/layers/gemma.py
+++ b/src/MaxText/layers/gemma.py
@@ -132,6 +132,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
diff --git a/src/MaxText/layers/gemma2.py b/src/MaxText/layers/gemma2.py
index 3d0d39efe..116983073 100644
--- a/src/MaxText/layers/gemma2.py
+++ b/src/MaxText/layers/gemma2.py
@@ -226,6 +226,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
diff --git a/src/MaxText/layers/gemma3.py b/src/MaxText/layers/gemma3.py
index 3ecc8991c..1906af5aa 100644
--- a/src/MaxText/layers/gemma3.py
+++ b/src/MaxText/layers/gemma3.py
@@ -193,6 +193,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/gpt3.py b/src/MaxText/layers/gpt3.py
index 831677583..c7c2e1835 100644
--- a/src/MaxText/layers/gpt3.py
+++ b/src/MaxText/layers/gpt3.py
@@ -346,6 +346,9 @@ def __call__(
   ):
     cfg = self.config
     mesh = self.mesh
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/gpt_oss.py b/src/MaxText/layers/gpt_oss.py
index 1301a46b9..111635910 100644
--- a/src/MaxText/layers/gpt_oss.py
+++ b/src/MaxText/layers/gpt_oss.py
@@ -149,6 +149,9 @@ def __call__(
       attention_metadata=None,
   ):
     cfg = self.config
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/llama2.py b/src/MaxText/layers/llama2.py
index 7148b1a5b..3b1310655 100644
--- a/src/MaxText/layers/llama2.py
+++ b/src/MaxText/layers/llama2.py
@@ -152,6 +152,9 @@ def __call__(
   ):
     cfg = self.config
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     lnx_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.activation_axis_names))
diff --git a/src/MaxText/layers/llama4.py b/src/MaxText/layers/llama4.py
index 9db76ae5b..04828881f 100644
--- a/src/MaxText/layers/llama4.py
+++ b/src/MaxText/layers/llama4.py
@@ -454,6 +454,9 @@ def __call__(
     cfg = self.config
     assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/mistral.py b/src/MaxText/layers/mistral.py
index 643fecaae..2ba140644 100644
--- a/src/MaxText/layers/mistral.py
+++ b/src/MaxText/layers/mistral.py
@@ -135,9 +135,11 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
-    cfg = self.config
 
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
     lnx = self.pre_self_attention_layer_norm(inputs)
diff --git a/src/MaxText/layers/mixtral.py b/src/MaxText/layers/mixtral.py
index 8d23e72d3..d90c3015a 100644
--- a/src/MaxText/layers/mixtral.py
+++ b/src/MaxText/layers/mixtral.py
@@ -140,7 +140,9 @@ def __call__(
       kv_cache=None,
       attention_metadata=None,
   ):
-
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py
index 5898d422f..62eb67ea4 100644
--- a/src/MaxText/layers/qwen3.py
+++ b/src/MaxText/layers/qwen3.py
@@ -708,7 +708,9 @@ def __call__(
     # Loop over the number of sub-layers that make up one repeating pattern.
     for i in range(cfg.inhomogeneous_layer_cycle_interval):
       layer = getattr(self, f"layer_{i}")
-      x = layer(
+      # The second return value is kv_cache, which we ignore here because
+      # it is not passed as a carry in scannable layers.
+      x, _ = layer(
           x,
           decoder_segment_ids,
           decoder_positions,
@@ -802,6 +804,9 @@ def __call__(
       kv_cache: None | jnp.ndarray = None,
       attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     residual = inputs
 
     # First LayerNorm, applied before the attention block.
@@ -1001,6 +1006,9 @@ def __call__(
       kv_cache: None | jnp.ndarray = None,
       attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
         inputs,
         decoder_segment_ids,
@@ -1065,6 +1073,9 @@ def __call__(
       kv_cache: None | jnp.ndarray = None,
       attention_metadata: None | dict[str, Any] = None,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
         inputs,
         decoder_segment_ids,
diff --git a/src/MaxText/layers/simple_layer.py b/src/MaxText/layers/simple_layer.py
index fdfbc07d1..d695ccc60 100644
--- a/src/MaxText/layers/simple_layer.py
+++ b/src/MaxText/layers/simple_layer.py
@@ -61,6 +61,9 @@ def __init__(
   def __call__(
       self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, previous_chunk=None, page_state=None
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     if self.config.scan_layers:
       return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding), None
     return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding)
@@ -124,6 +127,9 @@ def __call__(
       page_state=None,
       slot=0,
   ):
+    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
+    if isinstance(inputs, tuple):
+      inputs = inputs[0]
     intermediate = jnp.dot(inputs, self.ff_1.astype(inputs.dtype), out_sharding=self.mlp_sharding)
     output = jnp.dot(intermediate, self.ff_2.astype(inputs.dtype), out_sharding=self.activation_sharding)
     if self.config.scan_layers:
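
Note (appended for context, not part of the patch): every hunk above adds the same defensive guard because a decoder layer may now return a `(hidden_states, kv_cache)` tuple that the next layer receives as `inputs`. The sketch below is a minimal illustration of that pattern using a hypothetical `ToyDecoderLayer`; the class, its shapes, and its body are assumptions for demonstration only and are not MaxText APIs.

```python
# Illustrative sketch only: a toy decoder layer that returns
# (hidden_states, kv_cache) when layers are scanned, plus the tuple-unpacking
# guard the patch inserts at the top of each real layer's __call__.
import jax.numpy as jnp


class ToyDecoderLayer:
  """Hypothetical stand-in for a MaxText decoder layer."""

  def __init__(self, scan_layers: bool):
    self.scan_layers = scan_layers

  def __call__(self, inputs, kv_cache=None):
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)).
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    hidden_states = inputs + 1.0  # placeholder for attention + MLP
    if self.scan_layers:
      # Scanned layers return a pair; kv_cache is not threaded through as a carry.
      return hidden_states, kv_cache
    return hidden_states


x = jnp.ones((2, 4, 8))
for layer in [ToyDecoderLayer(scan_layers=True) for _ in range(3)]:
  # Without the guard, the second layer would receive a tuple here,
  # which is the failure the patch works around. The caller can also
  # discard the second value explicitly, as qwen3.py does with `x, _ = layer(...)`.
  x = layer(x)
x = x[0] if isinstance(x, tuple) else x
print(x.shape)  # (2, 4, 8)
```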