Skip to content

Commit 64048ba

Browse files
committed
Migrate Pipeline to use NNX module.
1 parent bac545a commit 64048ba

File tree

6 files changed

+2815
-392
lines changed

6 files changed

+2815
-392
lines changed

src/MaxText/layers/decoders.py

Lines changed: 231 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,138 @@ def __call__(
260260
return inputs
261261

262262

263+
class SequentialNNXWrapper(nnx.Module):
  """Pipeline-stage adapter around a stack of sequential decoder layers.

  Exposes the same call signature Pipeline expects from a single decoder
  layer, while internally running ``num_decoder_layers`` stacked layers.
  """

  def __init__(
      self,
      decoder_layer_class: type,
      num_decoder_layers: int,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
  ):
    """Build the wrapped sequential stack of decoder layers.

    Args:
      decoder_layer_class: NNX decoder layer class to instantiate.
      num_decoder_layers: How many stacked layers this stage holds.
      config: Model configuration.
      mesh: Device mesh.
      model_mode: 'train', 'eval', etc.
      rngs: RNG state used to initialize the layers.
      quant: Optional quantization config.
    """
    # All real work is delegated to the sequential block; this class only
    # adapts it to the single-layer interface Pipeline expects.
    stack = SequentialBlockNNXDecoderLayers(
        decoder_layer_class=decoder_layer_class,
        num_decoder_layers=num_decoder_layers,
        config=config,
        mesh=mesh,
        model_mode=model_mode,
        rngs=rngs,
        quant=quant,
    )
    self.sequential = stack

  def __call__(self, *args, **kwargs):
    """Run the inputs through every layer of the wrapped stack."""
    return self.sequential(*args, **kwargs)
303+
304+
305+
class SequentialBlockNNXDecoderLayers(nnx.Module):
  """Unscanned stack of independently-constructed NNX decoder layers."""

  def __init__(
      self,
      decoder_layer_class: type,
      num_decoder_layers: int,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
  ):
    """Instantiate ``num_decoder_layers`` independent decoder layers.

    Args:
      decoder_layer_class: The NNX decoder layer class to instantiate.
      num_decoder_layers: Number of decoder layers to create.
      config: Model configuration.
      mesh: Device mesh for sharding.
      model_mode: 'train', 'eval', etc.
      rngs: RNG state for initialization.
      quant: Optional quantization configuration.
    """
    self.config = config
    self.num_decoder_layers = num_decoder_layers

    # NNX tracks module state through attributes, not through plain Python
    # lists, so each layer is stored under its own uniquely-named attribute
    # (layer_0, layer_1, ...) to keep it visible as a pytree node.
    for idx in range(num_decoder_layers):
      setattr(
          self,
          f"layer_{idx}",
          decoder_layer_class(
              config=config,
              mesh=mesh,
              model_mode=model_mode,
              rngs=rngs,
              quant=quant,
          ),
      )

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids,
      decoder_positions,
      deterministic: bool,
      model_mode,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
  ) -> jnp.ndarray:
    """Apply every decoder layer in order.

    Args:
      inputs: Input tensor.
      decoder_segment_ids: Segment IDs for attention masking.
      decoder_positions: Position indices.
      deterministic: Whether to use deterministic mode (no dropout).
      model_mode: 'train', 'eval', etc.
      slot: Optional slot index for paged attention.
      page_state: Optional page state for paged attention.

    Returns:
      The activations after the final layer; when ``config.scan_layers`` is
      set, a ``(activations, None)`` tuple to match the scanned-layer format.
    """
    scanned = self.config.scan_layers
    hidden = inputs
    for idx in range(self.num_decoder_layers):
      layer = getattr(self, f"layer_{idx}")
      result = layer(
          hidden,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          slot=slot,
          page_state=page_state,
      )
      # Scanned-style layers return (outputs, None); unwrap before chaining
      # into the next layer.
      hidden = result[0] if scanned else result

    return (hidden, None) if scanned else hidden
393+
394+
263395
class Decoder(nn.Module):
264396
"""A stack of decoder layers as a part of an encoder-decoder architecture."""
265397

@@ -273,10 +405,17 @@ def setup(self):
273405
self.decoder_layer = self.get_decoder_layers()
274406
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
275407
if self.config.using_pipeline_parallelism:
276-
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
408+
# Try to get pure NNX decoder classes for pipeline parallelism
409+
nnx_decoder_classes = self.get_nnx_decoder_layers()
410+
if nnx_decoder_classes is not None:
411+
# Use pure NNX classes for pipeline - pass the class, not instance
412+
pipeline_stage_module = self.get_pipeline_stage_module(nnx_decoder_classes, use_nnx=True)
413+
else:
414+
# Fallback to Linen-wrapped classes
415+
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer, use_nnx=False)
277416
remat_policy = self.get_remat_policy()
278-
self.pipeline_module = pipeline.Pipeline(
279-
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
417+
self.pipeline_module = pipeline.create_pipeline(
418+
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy, use_nnx=(nnx_decoder_classes is not None)
280419
)
281420

282421
def minimal_policy(self, with_context=False):
@@ -431,6 +570,63 @@ def get_decoder_layers(self):
431570
# Default case to handle any unknown decoder block types.
432571
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
433572

573+
def get_nnx_decoder_layers(self):
574+
"""Retrieves pure NNX decoder layer classes (without Linen wrappers) for pipeline parallelism.
575+
576+
Returns:
577+
A list containing one or more NNX Module classes for the decoder.
578+
"""
579+
match self.config.decoder_block:
580+
case DecoderBlockType.DEFAULT:
581+
# DecoderLayer is Linen-only, no NNX version available
582+
return None
583+
case DecoderBlockType.LLAMA2:
584+
return [llama2.LlamaDecoderLayer] # Pure NNX version
585+
case DecoderBlockType.MISTRAL:
586+
return [mistral.MistralDecoderLayer] if hasattr(mistral, 'MistralDecoderLayer') else None
587+
case DecoderBlockType.MIXTRAL:
588+
return [mixtral.MixtralDecoderLayer] if hasattr(mixtral, 'MixtralDecoderLayer') else None
589+
case DecoderBlockType.DEEPSEEK:
590+
# DeepSeek uses specific dense/MoE layers
591+
if self.config.use_batch_split_schedule:
592+
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
593+
else:
594+
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
595+
case DecoderBlockType.GEMMA:
596+
return [gemma.GemmaDecoderLayer] if hasattr(gemma, 'GemmaDecoderLayer') else None
597+
case DecoderBlockType.GEMMA2:
598+
return [gemma2.Gemma2DecoderLayer] if hasattr(gemma2, 'Gemma2DecoderLayer') else None
599+
case DecoderBlockType.GEMMA3:
600+
return [gemma3.Gemma3DecoderLayer] if hasattr(gemma3, 'Gemma3DecoderLayer') else None
601+
case DecoderBlockType.GPT3:
602+
return [gpt3.Gpt3DecoderLayer]
603+
case DecoderBlockType.GPT_OSS:
604+
# Check if pure NNX version exists
605+
if self.config.scan_layers:
606+
return [gpt_oss.GptOssScannableBlock] if hasattr(gpt_oss, 'GptOssScannableBlock') else None
607+
else:
608+
return [gpt_oss.GptOssDecoderLayer] if hasattr(gpt_oss, 'GptOssDecoderLayer') else None
609+
case DecoderBlockType.QWEN3:
610+
return [qwen3.Qwen3DecoderLayer] if hasattr(qwen3, 'Qwen3DecoderLayer') else None
611+
case DecoderBlockType.QWEN3_MOE:
612+
return [qwen3.Qwen3MoeDecoderLayer] if hasattr(qwen3, 'Qwen3MoeDecoderLayer') else None
613+
case DecoderBlockType.QWEN3_NEXT:
614+
if self.config.scan_layers:
615+
return [qwen3.Qwen3NextScannableBlock] if hasattr(qwen3, 'Qwen3NextScannableBlock') else None
616+
else:
617+
return [qwen3.Qwen3NextDecoderLayer] if hasattr(qwen3, 'Qwen3NextDecoderLayer') else None
618+
case DecoderBlockType.SIMPLE:
619+
return [simple_layer.SimpleDecoderLayer] # Pure NNX version
620+
case DecoderBlockType.SIMPLE_MLP:
621+
return [simple_layer.SimpleMlpDecoderLayer] # Pure NNX version
622+
case DecoderBlockType.LLAMA4:
623+
if self.config.scan_layers:
624+
return [llama4.Llama4ScannableBlock] if hasattr(llama4, 'Llama4ScannableBlock') else None
625+
else:
626+
return [llama4.Llama4DecoderLayer] if hasattr(llama4, 'Llama4DecoderLayer') else None
627+
case _:
628+
return None
629+
434630
def set_remat_policy(self, block_layers, policy):
435631
"""Set remat policy"""
436632
RemattedBlockLayers = []
@@ -510,8 +706,14 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
510706
config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args
511707
)
512708

513-
def get_pipeline_stage_module(self, decoder_blocks):
514-
"""get pipeline stage module"""
709+
def get_pipeline_stage_module(self, decoder_blocks, use_nnx=False):
710+
"""get pipeline stage module
711+
712+
Args:
713+
decoder_blocks: List of decoder layer classes (either Linen or NNX)
714+
use_nnx: If True, decoder_blocks are NNX classes and should be passed to Pipeline
715+
without instantiation. Pipeline will handle NNX instantiation with proper rngs.
716+
"""
515717

516718
def get_layer_to_pipeline(blocks, cfg):
517719
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
@@ -521,6 +723,30 @@ def get_layer_to_pipeline(blocks, cfg):
521723

522724
cfg = self.config
523725
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
726+
727+
# For NNX classes, return a class that Pipeline can instantiate
728+
if use_nnx:
729+
if cfg.num_layers_per_pipeline_stage == 1:
730+
# Return the NNX class itself, Pipeline will instantiate it
731+
return base_stage
732+
else:
733+
# For multiple layers per stage, return a partial wrapper
734+
max_logging.log(
735+
f"Pipeline: Creating sequential NNX wrapper with {cfg.num_layers_per_pipeline_stage} layers per stage"
736+
)
737+
# Return a lambda that creates the wrapper with the right parameters
738+
# Pipeline will call this with (config, mesh, model_mode, rngs, quant)
739+
return lambda config, mesh, model_mode, rngs, quant=None: SequentialNNXWrapper(
740+
decoder_layer_class=base_stage,
741+
num_decoder_layers=cfg.num_layers_per_pipeline_stage,
742+
config=config,
743+
mesh=mesh,
744+
model_mode=model_mode,
745+
rngs=rngs,
746+
quant=quant
747+
)
748+
749+
# For Linen classes, instantiate as before
524750
if cfg.set_remat_policy_on_layers_per_stage:
525751
policy = self.get_remat_policy()
526752
base_stage = self.set_remat_policy([base_stage], policy)[0]

0 commit comments

Comments
 (0)