Skip to content

Commit a8148db

Browse files
Migrate multi_token_prediction to NNX
1 parent 4f6cdf4 commit a8148db

File tree

3 files changed

+387
-266
lines changed

3 files changed

+387
-266
lines changed

src/MaxText/layers/models.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
from MaxText.layers.embeddings import Embed, embed_as_linen
3535
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
3636
from MaxText.layers.quantizations import AqtQuantization as Quant
37-
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock
38-
from MaxText.sharding import all_gather_over_fsdp
37+
from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
38+
from MaxText.maxtext_utils import all_gather_over_fsdp
3939

4040
# ------------------------------------------------------------------------------
4141
# The network: Transformer Definitions
@@ -94,8 +94,12 @@ def setup(self):
9494
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
9595
# By convention, this is the last layer in the list.
9696
mtp_layer = layer_types[-1]
97-
self.mtp_block = MultiTokenPredictionBlock(
98-
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
97+
self.mtp_block = multi_token_prediction_block_as_linen(
98+
config=self.config,
99+
mesh=self.mesh,
100+
transformer_layer_module=mtp_layer,
101+
decoder=self.decoder,
102+
rngs=self.make_rng("mtp_block"),
99103
)
100104

101105
def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
@@ -285,7 +289,15 @@ class Transformer(nnx.Module):
285289
# Make new attributes required, so that all Transformer dependencies (train, decode,
286290
# compile, etc) will error instead of silently use defaults.
287291
# pylint: disable=attribute-defined-outside-init
288-
def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs):
292+
def __init__(
293+
self,
294+
config: Config,
295+
mesh: Mesh,
296+
quant: Quant,
297+
*,
298+
model_mode: str = MODEL_MODE_TRAIN,
299+
rngs: nnx.Rngs,
300+
):
289301
"""Initialize shared_embedding & decoder layers."""
290302
self.config = config
291303
self.mesh = mesh
@@ -347,8 +359,13 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
347359
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
348360
# By convention, this is the last layer in the list.
349361
mtp_layer = layer_types[-1]
350-
mtp_block_linen = MultiTokenPredictionBlock(
351-
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
362+
mtp_block_linen = multi_token_prediction_block_as_linen(
363+
config=self.config,
364+
mesh=self.mesh,
365+
transformer_layer_module=mtp_layer,
366+
decoder=self.decoder,
367+
rngs=rngs,
368+
name="mtp_block",
352369
)
353370
self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)
354371

@@ -593,7 +610,10 @@ def __call__(
593610
page_state=page_state,
594611
)
595612
all_model_weights = all_gather_over_fsdp(
596-
self.model.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules
613+
self.model.variables,
614+
partition_spec,
615+
mesh=self.mesh,
616+
logical_axis_rules=self.config.logical_axis_rules,
597617
)
598618

599619
return self.model.apply(

0 commit comments

Comments
 (0)