@@ -34,8 +34,8 @@
 from MaxText.layers.embeddings import Embed, embed_as_linen
 from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock
-from MaxText.sharding import all_gather_over_fsdp
+from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
+from MaxText.maxtext_utils import all_gather_over_fsdp
 
 # ------------------------------------------------------------------------------
 # The network: Transformer Definitions
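
The import swap above tracks the NNX migration: direct use of the MultiTokenPredictionBlock class is replaced by a factory that returns a Linen-compatible wrapper, matching the embed_as_linen and vision_encoder_as_linen helpers already imported. A minimal sketch of that convention, assuming nnx_wrappers exposes a to_linen helper mirroring flax.nnx.bridge.to_linen (the factory signature below is inferred from the call sites in this diff, not the actual definition):

# Hypothetical factory sketch; keyword set inferred from the call sites below.
from MaxText.layers import nnx_wrappers
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock

def multi_token_prediction_block_as_linen(*, config, mesh, transformer_layer_module, decoder, rngs=None, name=None):
  """Return a Linen module that lazily constructs the NNX MTP block."""
  return nnx_wrappers.to_linen(
      MultiTokenPredictionBlock,
      config=config,
      mesh=mesh,
      transformer_layer_module=transformer_layer_module,
      decoder=decoder,
      rngs=rngs,
      name=name,
  )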
@@ -94,8 +94,12 @@ def setup(self):
     # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
     # By convention, this is the last layer in the list.
     mtp_layer = layer_types[-1]
-    self.mtp_block = MultiTokenPredictionBlock(
-        config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
+    self.mtp_block = multi_token_prediction_block_as_linen(
+        config=self.config,
+        mesh=self.mesh,
+        transformer_layer_module=mtp_layer,
+        decoder=self.decoder,
+        rngs=self.make_rng("mtp_block"),
     )
 
   def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
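
Since the Linen path now seeds the bridge via self.make_rng("mtp_block"), the "mtp_block" rng stream has to be supplied alongside "params" at init/apply time or make_rng will raise. A hedged usage sketch (linen_model and example_inputs are illustrative placeholders):

import jax

# Split one root key into the named streams the module consumes.
root_key = jax.random.PRNGKey(0)
params_key, mtp_key = jax.random.split(root_key)
variables = linen_model.init(
    {"params": params_key, "mtp_block": mtp_key},  # rng streams keyed by name
    example_inputs,
)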
@@ -285,7 +289,15 @@ class Transformer(nnx.Module):
   # Make new attributes required, so that all Transformer dependencies (train, decode,
   # compile, etc) will error instead of silently use defaults.
   # pylint: disable=attribute-defined-outside-init
-  def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs):
+  def __init__(
+      self,
+      config: Config,
+      mesh: Mesh,
+      quant: Quant,
+      *,
+      model_mode: str = MODEL_MODE_TRAIN,
+      rngs: nnx.Rngs,
+  ):
     """Initialize shared_embedding & decoder layers."""
     self.config = config
     self.mesh = mesh
@@ -347,8 +359,13 @@ def __init__(self, config: Config, mesh: Mesh, quant: Quant, *, model_mode: str
     # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
     # By convention, this is the last layer in the list.
     mtp_layer = layer_types[-1]
-    mtp_block_linen = MultiTokenPredictionBlock(
-        config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
+    mtp_block_linen = multi_token_prediction_block_as_linen(
+        config=self.config,
+        mesh=self.mesh,
+        transformer_layer_module=mtp_layer,
+        decoder=self.decoder,
+        rngs=rngs,
+        name="mtp_block",
     )
     self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)
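
Here the freshly built Linen wrapper is immediately lifted back into NNX with nnx_wrappers.ToNNX so it can live inside the NNX Transformer. A round-trip sketch of that pattern, assuming MaxText's wrapper mirrors flax.nnx.bridge.ToNNX (Dense2 is an illustrative stand-in):

import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

class Dense2(nn.Module):  # illustrative Linen stand-in
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

# ToNNX defers Linen variable creation; lazy_init materializes the variables
# from the given rng streams, after which the wrapper acts like a native NNX module.
model = nnx.bridge.ToNNX(Dense2(features=4), rngs=nnx.Rngs(0))
nnx.bridge.lazy_init(model, jnp.ones((2, 3)))
y = model(jnp.ones((2, 3)))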
@@ -593,7 +610,10 @@ def __call__(
         page_state=page_state,
     )
     all_model_weights = all_gather_over_fsdp(
-        self.model.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules
+        self.model.variables,
+        partition_spec,
+        mesh=self.mesh,
+        logical_axis_rules=self.config.logical_axis_rules,
     )
 
     return self.model.apply(
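
The only change here is the import home of all_gather_over_fsdp (MaxText.maxtext_utils instead of MaxText.sharding) plus formatting; behavior is unchanged. For intuition, a simplified sketch of what gathering weights over the FSDP axis amounts to, under the assumption of single-name mesh axes (the real helper additionally resolves logical_axis_rules to mesh axes):

import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def gather_over_fsdp_sketch(params, specs, mesh, fsdp_axis="fsdp"):
  """Drop the FSDP axis from each weight's PartitionSpec and re-constrain.

  Under jit, XLA turns the looser constraint into an all-gather, so every
  device ends up holding the full (unsharded-over-FSDP) weight.
  """
  def strip(spec):
    return P(*(None if part == fsdp_axis else part for part in spec))
  return jax.tree.map(
      lambda w, s: jax.lax.with_sharding_constraint(w, NamedSharding(mesh, strip(s))),
      params,
      specs,
  )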