Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/MaxText/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from MaxText.layers.embeddings import Embed, embed_as_linen
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock
from MaxText.sharding import all_gather_over_fsdp
from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
from MaxText.maxtext_utils import all_gather_over_fsdp

# ------------------------------------------------------------------------------
# The network: Transformer Definitions
Expand Down Expand Up @@ -94,8 +94,12 @@ def setup(self):
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
# By convention, the MTP blueprint is the last entry in `layer_types`.
mtp_layer = layer_types[-1]
self.mtp_block = MultiTokenPredictionBlock(
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
self.mtp_block = multi_token_prediction_block_as_linen(
config=self.config,
mesh=self.mesh,
transformer_layer_module=mtp_layer,
decoder=self.decoder,
rngs=self.make_rng("mtp_block"),
)

def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
Expand Down Expand Up @@ -285,7 +289,15 @@ class Transformer(nnx.Module):
# Make new attributes required, so that all Transformer dependencies (train, decode,
# compile, etc.) raise an error instead of silently using defaults.
# pylint: disable=attribute-defined-outside-init
def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs):
def __init__(
self,
config: Config,
mesh: Mesh,
quant: Quant,
*,
model_mode: str = MODEL_MODE_TRAIN,
rngs: nnx.Rngs,
):
"""Initialize shared_embedding & decoder layers."""
self.config = config
self.mesh = mesh
Expand Down Expand Up @@ -347,8 +359,13 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
# By convention, the MTP blueprint is the last entry in `layer_types`.
mtp_layer = layer_types[-1]
mtp_block_linen = MultiTokenPredictionBlock(
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
mtp_block_linen = multi_token_prediction_block_as_linen(
config=self.config,
mesh=self.mesh,
transformer_layer_module=mtp_layer,
decoder=self.decoder,
rngs=rngs,
name="mtp_block",
)
self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)

Expand Down Expand Up @@ -593,7 +610,10 @@ def __call__(
page_state=page_state,
)
all_model_weights = all_gather_over_fsdp(
self.model.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules
self.model.variables,
partition_spec,
mesh=self.mesh,
logical_axis_rules=self.config.logical_axis_rules,
)

return self.model.apply(
Expand Down
Loading
Loading