@@ -260,6 +260,138 @@ def __call__(
260260 return inputs
261261
262262
263+ class SequentialNNXWrapper (nnx .Module ):
264+ """Wrapper that creates sequential decoder layers for pipeline stages.
265+
266+ This wrapper matches the decoder layer signature expected by Pipeline.
267+ """
268+
269+ def __init__ (
270+ self ,
271+ decoder_layer_class : type ,
272+ num_decoder_layers : int ,
273+ config : Config ,
274+ mesh : Mesh ,
275+ model_mode : str ,
276+ rngs : nnx .Rngs ,
277+ quant : None | Quant = None ,
278+ ):
279+ """Initialize wrapper with sequential decoder layers.
280+
281+ Args:
282+ decoder_layer_class: NNX decoder layer class to instantiate
283+ num_decoder_layers: Number of layers to create
284+ config: Model configuration
285+ mesh: Device mesh
286+ model_mode: 'train', 'eval', etc.
287+ rngs: RNG state
288+ quant: Quantization config
289+ """
290+ self .sequential = SequentialBlockNNXDecoderLayers (
291+ decoder_layer_class = decoder_layer_class ,
292+ num_decoder_layers = num_decoder_layers ,
293+ config = config ,
294+ mesh = mesh ,
295+ model_mode = model_mode ,
296+ rngs = rngs ,
297+ quant = quant
298+ )
299+
300+ def __call__ (self , * args , ** kwargs ):
301+ """Forward pass through sequential layers."""
302+ return self .sequential (* args , ** kwargs )
303+
304+
305+ class SequentialBlockNNXDecoderLayers (nnx .Module ):
306+ """Sequential unscanned series of NNX decoder layers."""
307+
308+ def __init__ (
309+ self ,
310+ decoder_layer_class : type ,
311+ num_decoder_layers : int ,
312+ config : Config ,
313+ mesh : Mesh ,
314+ model_mode : str ,
315+ rngs : nnx .Rngs ,
316+ quant : None | Quant = None ,
317+ ):
318+ """Initialize multiple NNX decoder layer instances.
319+
320+ Args:
321+ decoder_layer_class: The NNX decoder layer class to instantiate
322+ num_decoder_layers: Number of decoder layers to create
323+ config: Model configuration
324+ mesh: Device mesh for sharding
325+ model_mode: 'train', 'eval', etc.
326+ rngs: RNG state for initialization
327+ quant: Quantization configuration
328+ """
329+ self .config = config
330+ self .num_decoder_layers = num_decoder_layers
331+
332+ # Create multiple independent decoder layer instances
333+ # IMPORTANT: Store as individual attributes so NNX tracks them as pytree nodes
334+ # Regular Python lists are not tracked by NNX!
335+ for lyr in range (num_decoder_layers ):
336+ layer = decoder_layer_class (
337+ config = config ,
338+ mesh = mesh ,
339+ model_mode = model_mode ,
340+ rngs = rngs ,
341+ quant = quant ,
342+ )
343+ # Store as attribute with unique name so NNX can track it
344+ setattr (self , f'layer_{ lyr } ' , layer )
345+
346+ def __call__ (
347+ self ,
348+ inputs : jnp .ndarray ,
349+ decoder_segment_ids ,
350+ decoder_positions ,
351+ deterministic : bool ,
352+ model_mode ,
353+ slot : None | int = None ,
354+ page_state : None | page_manager .PageState = None ,
355+ ) -> jnp .ndarray :
356+ """Sequentially apply all decoder layers.
357+
358+ Args:
359+ inputs: Input tensor
360+ decoder_segment_ids: Segment IDs for attention masking
361+ decoder_positions: Position indices
362+ deterministic: Whether to use deterministic mode (no dropout)
363+ model_mode: 'train', 'eval', etc.
364+ slot: Optional slot index for paged attention
365+ page_state: Optional page state for paged attention
366+
367+ Returns:
368+ Output tensor after all layers, or (output, None) if scan_layers is True
369+ """
370+ # Iterate over layer attributes (layer_0, layer_1, ...)
371+ for lyr in range (self .num_decoder_layers ):
372+ layer = getattr (self , f'layer_{ lyr } ' )
373+ outputs = layer (
374+ inputs ,
375+ decoder_segment_ids ,
376+ decoder_positions ,
377+ deterministic ,
378+ model_mode ,
379+ slot = slot ,
380+ page_state = page_state ,
381+ )
382+ # Handle tuple outputs (e.g., from scan_layers)
383+ if self .config .scan_layers :
384+ inputs = outputs [0 ] # When scan_layers is True the decoder layers return (outputs, None).
385+ else :
386+ inputs = outputs
387+
388+ # Return format matching scan_layers configuration
389+ if self .config .scan_layers :
390+ return inputs , None
391+ else :
392+ return inputs
393+
394+
263395class Decoder (nn .Module ):
264396 """A stack of decoder layers as a part of an encoder-decoder architecture."""
265397
@@ -273,10 +405,17 @@ def setup(self):
273405 self .decoder_layer = self .get_decoder_layers ()
274406 self .norm_layer = self .get_norm_layer (num_features = self .config .emb_dim )
275407 if self .config .using_pipeline_parallelism :
276- pipeline_stage_module = self .get_pipeline_stage_module (self .decoder_layer )
408+ # Try to get pure NNX decoder classes for pipeline parallelism
409+ nnx_decoder_classes = self .get_nnx_decoder_layers ()
410+ if nnx_decoder_classes is not None :
411+ # Use pure NNX classes for pipeline - pass the class, not instance
412+ pipeline_stage_module = self .get_pipeline_stage_module (nnx_decoder_classes , use_nnx = True )
413+ else :
414+ # Fallback to Linen-wrapped classes
415+ pipeline_stage_module = self .get_pipeline_stage_module (self .decoder_layer , use_nnx = False )
277416 remat_policy = self .get_remat_policy ()
278- self .pipeline_module = pipeline .Pipeline (
279- config = self .config , mesh = self .mesh , layers = pipeline_stage_module , remat_policy = remat_policy
417+ self .pipeline_module = pipeline .create_pipeline (
418+ config = self .config , mesh = self .mesh , layers = pipeline_stage_module , remat_policy = remat_policy , use_nnx = ( nnx_decoder_classes is not None )
280419 )
281420
282421 def minimal_policy (self , with_context = False ):
@@ -431,6 +570,63 @@ def get_decoder_layers(self):
431570 # Default case to handle any unknown decoder block types.
432571 raise ValueError (f"Incorrect decoder_block name { self .config .decoder_block .value = } " )
433572
573+ def get_nnx_decoder_layers (self ):
574+ """Retrieves pure NNX decoder layer classes (without Linen wrappers) for pipeline parallelism.
575+
576+ Returns:
577+ A list containing one or more NNX Module classes for the decoder.
578+ """
579+ match self .config .decoder_block :
580+ case DecoderBlockType .DEFAULT :
581+ # DecoderLayer is Linen-only, no NNX version available
582+ return None
583+ case DecoderBlockType .LLAMA2 :
584+ return [llama2 .LlamaDecoderLayer ] # Pure NNX version
585+ case DecoderBlockType .MISTRAL :
586+ return [mistral .MistralDecoderLayer ] if hasattr (mistral , 'MistralDecoderLayer' ) else None
587+ case DecoderBlockType .MIXTRAL :
588+ return [mixtral .MixtralDecoderLayer ] if hasattr (mixtral , 'MixtralDecoderLayer' ) else None
589+ case DecoderBlockType .DEEPSEEK :
590+ # DeepSeek uses specific dense/MoE layers
591+ if self .config .use_batch_split_schedule :
592+ return [deepseek_batchsplit .DeepSeekDenseLayer , deepseek_batchsplit .DeepSeekMoELayer ]
593+ else :
594+ return [deepseek .DeepSeekDenseLayer , deepseek .DeepSeekMoELayer ]
595+ case DecoderBlockType .GEMMA :
596+ return [gemma .GemmaDecoderLayer ] if hasattr (gemma , 'GemmaDecoderLayer' ) else None
597+ case DecoderBlockType .GEMMA2 :
598+ return [gemma2 .Gemma2DecoderLayer ] if hasattr (gemma2 , 'Gemma2DecoderLayer' ) else None
599+ case DecoderBlockType .GEMMA3 :
600+ return [gemma3 .Gemma3DecoderLayer ] if hasattr (gemma3 , 'Gemma3DecoderLayer' ) else None
601+ case DecoderBlockType .GPT3 :
602+ return [gpt3 .Gpt3DecoderLayer ]
603+ case DecoderBlockType .GPT_OSS :
604+ # Check if pure NNX version exists
605+ if self .config .scan_layers :
606+ return [gpt_oss .GptOssScannableBlock ] if hasattr (gpt_oss , 'GptOssScannableBlock' ) else None
607+ else :
608+ return [gpt_oss .GptOssDecoderLayer ] if hasattr (gpt_oss , 'GptOssDecoderLayer' ) else None
609+ case DecoderBlockType .QWEN3 :
610+ return [qwen3 .Qwen3DecoderLayer ] if hasattr (qwen3 , 'Qwen3DecoderLayer' ) else None
611+ case DecoderBlockType .QWEN3_MOE :
612+ return [qwen3 .Qwen3MoeDecoderLayer ] if hasattr (qwen3 , 'Qwen3MoeDecoderLayer' ) else None
613+ case DecoderBlockType .QWEN3_NEXT :
614+ if self .config .scan_layers :
615+ return [qwen3 .Qwen3NextScannableBlock ] if hasattr (qwen3 , 'Qwen3NextScannableBlock' ) else None
616+ else :
617+ return [qwen3 .Qwen3NextDecoderLayer ] if hasattr (qwen3 , 'Qwen3NextDecoderLayer' ) else None
618+ case DecoderBlockType .SIMPLE :
619+ return [simple_layer .SimpleDecoderLayer ] # Pure NNX version
620+ case DecoderBlockType .SIMPLE_MLP :
621+ return [simple_layer .SimpleMlpDecoderLayer ] # Pure NNX version
622+ case DecoderBlockType .LLAMA4 :
623+ if self .config .scan_layers :
624+ return [llama4 .Llama4ScannableBlock ] if hasattr (llama4 , 'Llama4ScannableBlock' ) else None
625+ else :
626+ return [llama4 .Llama4DecoderLayer ] if hasattr (llama4 , 'Llama4DecoderLayer' ) else None
627+ case _:
628+ return None
629+
434630 def set_remat_policy (self , block_layers , policy ):
435631 """Set remat policy"""
436632 RemattedBlockLayers = []
@@ -510,8 +706,14 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
510706 config = cfg , mesh = mesh , name = metadata_axis_name , quant = self .quant , ** kwargs # pytype: disable=wrong-keyword-args
511707 )
512708
513- def get_pipeline_stage_module (self , decoder_blocks ):
514- """get pipeline stage module"""
709+ def get_pipeline_stage_module (self , decoder_blocks , use_nnx = False ):
710+ """get pipeline stage module
711+
712+ Args:
713+ decoder_blocks: List of decoder layer classes (either Linen or NNX)
714+ use_nnx: If True, decoder_blocks are NNX classes and should be passed to Pipeline
715+ without instantiation. Pipeline will handle NNX instantiation with proper rngs.
716+ """
515717
516718 def get_layer_to_pipeline (blocks , cfg ):
517719 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
@@ -521,6 +723,30 @@ def get_layer_to_pipeline(blocks, cfg):
521723
522724 cfg = self .config
523725 base_stage = get_layer_to_pipeline (decoder_blocks , cfg )
726+
727+ # For NNX classes, return a class that Pipeline can instantiate
728+ if use_nnx :
729+ if cfg .num_layers_per_pipeline_stage == 1 :
730+ # Return the NNX class itself, Pipeline will instantiate it
731+ return base_stage
732+ else :
733+ # For multiple layers per stage, return a partial wrapper
734+ max_logging .log (
735+ f"Pipeline: Creating sequential NNX wrapper with { cfg .num_layers_per_pipeline_stage } layers per stage"
736+ )
737+ # Return a lambda that creates the wrapper with the right parameters
738+ # Pipeline will call this with (config, mesh, model_mode, rngs, quant)
739+ return lambda config , mesh , model_mode , rngs , quant = None : SequentialNNXWrapper (
740+ decoder_layer_class = base_stage ,
741+ num_decoder_layers = cfg .num_layers_per_pipeline_stage ,
742+ config = config ,
743+ mesh = mesh ,
744+ model_mode = model_mode ,
745+ rngs = rngs ,
746+ quant = quant
747+ )
748+
749+ # For Linen classes, instantiate as before
524750 if cfg .set_remat_policy_on_layers_per_stage :
525751 policy = self .get_remat_policy ()
526752 base_stage = self .set_remat_policy ([base_stage ], policy )[0 ]
0 commit comments