Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/MaxText/layers/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def __call__(
logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
else:
logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")

Expand Down Expand Up @@ -240,6 +243,10 @@ def __call__(
logical_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
else:
logical_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")

Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/deepseek_batchsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __call__(
kv_cache=None,
attention_metadata=None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
x = self.with_logical_constraint(inputs)
x = jax.ad_checkpoint.checkpoint_name(x, "decoder_layer_input")

Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def __call__(
kv_cache=None,
attention_metadata=None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def __call__(
kv_cache=None,
attention_metadata=None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __call__(
attention_metadata=None,
):
cfg = self.config
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")

Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ def __call__(
):
cfg = self.config
mesh = self.mesh
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def __call__(
attention_metadata=None,
):
cfg = self.config
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def __call__(
):
cfg = self.config

# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.activation_axis_names))
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,9 @@ def __call__(
cfg = self.config
assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."

# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")

Expand Down
4 changes: 3 additions & 1 deletion src/MaxText/layers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,11 @@ def __call__(
kv_cache=None,
attention_metadata=None,
):

cfg = self.config

# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx = self.pre_self_attention_layer_norm(inputs)
Expand Down
4 changes: 3 additions & 1 deletion src/MaxText/layers/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def __call__(
kv_cache=None,
attention_metadata=None,
):

# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")

Expand Down
13 changes: 12 additions & 1 deletion src/MaxText/layers/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,9 @@ def __call__(
# Loop over the number of sub-layers that make up one repeating pattern.
for i in range(cfg.inhomogeneous_layer_cycle_interval):
layer = getattr(self, f"layer_{i}")
x = layer(
# The second return value is kv_cache, which we ignore here because
# it is not passed as a carry in scannable layers.
x, _ = layer(
x,
decoder_segment_ids,
decoder_positions,
Expand Down Expand Up @@ -802,6 +804,9 @@ def __call__(
kv_cache: None | jnp.ndarray = None,
attention_metadata: None | dict[str, Any] = None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
residual = inputs

# First LayerNorm, applied before the attention block.
Expand Down Expand Up @@ -1001,6 +1006,9 @@ def __call__(
kv_cache: None | jnp.ndarray = None,
attention_metadata: None | dict[str, Any] = None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
inputs,
decoder_segment_ids,
Expand Down Expand Up @@ -1065,6 +1073,9 @@ def __call__(
kv_cache: None | jnp.ndarray = None,
attention_metadata: None | dict[str, Any] = None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm(
inputs,
decoder_segment_ids,
Expand Down
6 changes: 6 additions & 0 deletions src/MaxText/layers/simple_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __init__(
def __call__(
self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, previous_chunk=None, page_state=None
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
if self.config.scan_layers:
return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding), None
return jnp.dot(inputs, self.weights.astype(inputs.dtype), out_sharding=self.out_sharding)
Expand Down Expand Up @@ -124,6 +127,9 @@ def __call__(
page_state=None,
slot=0,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
if isinstance(inputs, tuple):
inputs = inputs[0]
intermediate = jnp.dot(inputs, self.ff_1.astype(inputs.dtype), out_sharding=self.mlp_sharding)
output = jnp.dot(intermediate, self.ff_2.astype(inputs.dtype), out_sharding=self.activation_sharding)
if self.config.scan_layers:
Expand Down
Loading