3131from tensorrt_llm .logger import logger
3232
3333from ..attention_backend import AttentionMetadata
34- from ..distributed import AllReduce
34+ from ..distributed import AllReduce , AllReduceFusionOp , AllReduceParams
3535from ..model_config import ModelConfig
3636from ..modules .attention import Attention
3737from ..modules .decoder_layer import DecoderLayer
@@ -59,6 +59,7 @@ def __init__(
5959 self ,
6060 model_config : ModelConfig [NemotronHConfig ],
6161 layer_idx : int ,
62+ reduce_output : bool = True ,
6263 ):
6364 config = model_config .pretrained_config
6465 if isinstance (config .intermediate_size , list ):
@@ -76,6 +77,7 @@ def __init__(
7677 activation = relu2 ,
7778 dtype = config .torch_dtype ,
7879 config = model_config ,
80+ reduce_output = reduce_output ,
7981 )
8082 self .layer_idx = layer_idx
8183
@@ -119,7 +121,8 @@ def forward(
119121 ) -> torch .Tensor :
120122 return super ().forward (position_ids = None ,
121123 hidden_states = hidden_states ,
122- attn_metadata = attn_metadata )
124+ attn_metadata = attn_metadata ,
125+ ** kwargs )
123126
124127
125128# Ref code: https://huggingface.co/nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024/blob/main/modeling_nemotron_h.py#L818
@@ -130,6 +133,7 @@ def __init__(
130133 model_config : ModelConfig [PretrainedConfig ],
131134 layer_idx : int ,
132135 aux_stream_dict : dict [AuxStreamType , torch .cuda .Stream ],
136+ reduce_output : bool = False ,
133137 ):
134138 super ().__init__ ()
135139
@@ -226,8 +230,7 @@ def __init__(
226230 activation_type = self .activation_type ,
227231 )
228232
229- if not model_config .mapping .enable_attention_dp :
230- # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
233+ if reduce_output :
231234 self .allreduce = AllReduce (
232235 mapping = model_config .mapping ,
233236 strategy = model_config .allreduce_strategy ,
@@ -324,8 +327,10 @@ def _compute_routed_output():
324327 final_hidden_states = shared_output + routed_output
325328
326329 # Perform all-reduce after combining outputs for multi-GPU support.
327- if not self .enable_attention_dp and self .mapping .tp_size > 1 :
328- final_hidden_states = self .allreduce (final_hidden_states )
330+ if self .allreduce is not None :
331+ final_hidden_states = self .allreduce (
332+ final_hidden_states ,
333+ all_reduce_params = kwargs .get ('all_reduce_params' ))
329334
330335 return final_hidden_states .view (orig_shape )
331336
@@ -341,6 +346,7 @@ def __init__(
341346 # * -> TransformerLayer
342347 layer_type : str ,
343348 aux_stream_dict : dict [AuxStreamType , torch .cuda .Stream ],
349+ fuse_allreduce_norm : bool = False ,
344350 ):
345351 super ().__init__ ()
346352
@@ -373,6 +379,13 @@ def __init__(
373379 )
374380 self .is_nvfp4 = False
375381
382+ # fuse_allreduce_norm is the model-level flag. When enabled, ALL
383+ # layers defer mixer AllReduce to the next layer's pre_allreduce (or
384+ # the model's final_allreduce). Only layers 1+ create a pre_allreduce
385+ # module; layer 0's input is already reduced from the embedding.
386+ self .fuse_allreduce_norm = fuse_allreduce_norm
387+ self .is_moe_layer = (layer_type == "E" )
388+
376389 self .norm = RMSNorm (
377390 hidden_size = config .hidden_size ,
378391 eps = config .rms_norm_eps ,
@@ -382,9 +395,22 @@ def __init__(
382395 quantize_type = "nvfp4" if self .is_nvfp4 else None ,
383396 # Enable high precision output for MoE layer (only with NVFP4).
384397 # It might be overridden in `_try_attach_nvfp4_scale` function.
385- return_hp_output = layer_type == "E" and self .is_nvfp4 ,
398+ return_hp_output = self . is_moe_layer and self .is_nvfp4 ,
386399 )
387400
401+ if fuse_allreduce_norm and layer_idx > 0 :
402+ self .pre_allreduce = AllReduce (
403+ mapping = model_config .mapping ,
404+ strategy = model_config .allreduce_strategy ,
405+ )
406+
407+ # Mixer creation. The fuse_allreduce_norm optimization is orthogonal
408+ # to AllReduce topology: Transformer/MoE gate it at forward time via
409+ # AllReduceParams; MLP/Mamba gate it at init time via reduce_output
410+ # (their base classes don't thread all_reduce_params through forward).
411+ has_tp_allreduce = (not model_config .mapping .enable_attention_dp
412+ and model_config .mapping .tp_size > 1 )
413+
388414 if layer_type == "M" :
389415 self .mixer = Mamba2Mixer (
390416 d_model = config .hidden_size ,
@@ -399,19 +425,27 @@ def __init__(
399425 dtype = config .torch_dtype ,
400426 config = model_config ,
401427 )
428+ if fuse_allreduce_norm :
429+ self .mixer .out_proj .reduce_output = False
402430 elif layer_type == "-" :
403- self .mixer = MLPLayer (model_config , layer_idx )
431+ self .mixer = MLPLayer (
432+ model_config ,
433+ layer_idx ,
434+ reduce_output = not fuse_allreduce_norm ,
435+ )
404436 elif layer_type == "*" :
405437 self .mixer = TransformerLayer (
406438 model_config ,
407439 layer_idx ,
408- reduce_output = not model_config .mapping .enable_attention_dp
409- and model_config .mapping .tp_size > 1 ,
440+ reduce_output = has_tp_allreduce ,
410441 )
411442 elif layer_type == "E" :
412- self .mixer = NemotronHMOE (model_config ,
413- layer_idx = layer_idx ,
414- aux_stream_dict = aux_stream_dict )
443+ self .mixer = NemotronHMOE (
444+ model_config ,
445+ layer_idx = layer_idx ,
446+ aux_stream_dict = aux_stream_dict ,
447+ reduce_output = has_tp_allreduce ,
448+ )
415449 else :
416450 raise ValueError (f"{ layer_type } is not supported" )
417451
@@ -436,7 +470,7 @@ def _try_attach_nvfp4_scale(self):
436470
437471 # Special handling for MoE layer: fetch shared_expert.up_proj.input_scale
438472 # as representation of the input scale.
439- if self .layer_type == "E" :
473+ if self .is_moe_layer :
440474 if (hasattr (self .mixer , "shared_experts" )
441475 and self .mixer .shared_experts is not None
442476 and hasattr (self .mixer .shared_experts , "up_proj" )
@@ -463,16 +497,50 @@ def forward(
463497 if residual is None :
464498 residual = torch .zeros_like (hidden_states )
465499
466- if self .norm .return_hp_output :
500+ if hasattr (self , 'pre_allreduce' ):
501+ norm = self .norm
502+ has_nvfp4_scale = hasattr (norm , 'nvfp4_scale' )
503+ if norm .is_nvfp4 and has_nvfp4_scale and norm .return_hp_output :
504+ fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4
505+ elif norm .is_nvfp4 and has_nvfp4_scale :
506+ fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_NVFP4
507+ else :
508+ fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM
509+ all_reduce_params = AllReduceParams (
510+ fusion_op = fusion_op ,
511+ residual = residual ,
512+ norm_weight = norm .weight ,
513+ eps = norm .variance_epsilon ,
514+ trigger_completion_at_end = False ,
515+ ** (dict (scale = norm .nvfp4_scale )
516+ if has_nvfp4_scale and norm .is_nvfp4 else {}),
517+ )
518+ result = self .pre_allreduce (hidden_states ,
519+ all_reduce_params = all_reduce_params )
520+ if fusion_op == AllReduceFusionOp .RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 :
521+ norm_out , act_fp4 , act_sf , residual = result
522+ hidden_states = (Fp4QuantizedTensor (act_fp4 , act_sf ), norm_out )
523+ elif fusion_op == AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_NVFP4 :
524+ act_fp4 , act_sf , residual = result
525+ hidden_states = Fp4QuantizedTensor (act_fp4 , act_sf )
526+ else :
527+ hidden_states , residual = result
528+ elif self .norm .return_hp_output :
467529 hidden_states , residual , high_precision_normed_output = self .norm (
468530 hidden_states , residual )
469531 hidden_states = (hidden_states , high_precision_normed_output )
470532 else :
471533 hidden_states , residual = self .norm (hidden_states , residual )
472- hidden_states = self .mixer (hidden_states ,
473- attn_metadata ,
474- spec_metadata = spec_metadata ,
475- ** kwargs )
534+
535+ # When fuse_allreduce_norm is active, tell Transformer/MoE mixers to
536+ # skip their own AllReduce (it is handled by pre_allreduce /
537+ # final_allreduce instead). MLP/Mamba ignore this kwarg; their
538+ # reduce_output was set at init time.
539+ mixer_kwargs = dict (spec_metadata = spec_metadata , ** kwargs )
540+ if self .fuse_allreduce_norm :
541+ mixer_kwargs ['all_reduce_params' ] = AllReduceParams (
542+ enable_allreduce = False )
543+ hidden_states = self .mixer (hidden_states , attn_metadata , ** mixer_kwargs )
476544
477545 if spec_metadata is not None and spec_metadata .is_layer_capture (
478546 self .layer_idx ):
@@ -519,14 +587,20 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
519587 gather_output = True ,
520588 )
521589
590+ self .fuse_allreduce_norm = (not model_config .mapping .enable_attention_dp
591+ and model_config .mapping .tp_size > 1 )
592+
522593 # create layers
523594 layers = []
524595 for layer_idx , layer_type in enumerate (config .hybrid_override_pattern ):
525596 layers .append (
526- NemotronHLayer (model_config ,
527- layer_idx ,
528- layer_type ,
529- aux_stream_dict = self .aux_stream_dict ))
597+ NemotronHLayer (
598+ model_config ,
599+ layer_idx ,
600+ layer_type ,
601+ aux_stream_dict = self .aux_stream_dict ,
602+ fuse_allreduce_norm = self .fuse_allreduce_norm ,
603+ ))
530604 self .layers = nn .ModuleList (layers )
531605 self .num_hidden_layers = config .num_hidden_layers
532606
@@ -537,6 +611,13 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
537611 dtype = config .torch_dtype ,
538612 )
539613
614+ # AllReduce for fusing with final norm (after last layer's mixer)
615+ if self .fuse_allreduce_norm :
616+ self .final_allreduce = AllReduce (
617+ mapping = model_config .mapping ,
618+ strategy = model_config .allreduce_strategy ,
619+ )
620+
540621 def forward (
541622 self ,
542623 attn_metadata : AttentionMetadata ,
@@ -567,7 +648,19 @@ def forward(
567648 spec_metadata = spec_metadata ,
568649 mamba_metadata = mamba_metadata ,
569650 )
570- hidden_states , _ = self .norm_f (hidden_states , residual )
651+
652+ if self .fuse_allreduce_norm :
653+ hidden_states , _ = self .final_allreduce (
654+ hidden_states ,
655+ all_reduce_params = AllReduceParams (
656+ fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM ,
657+ residual = residual ,
658+ norm_weight = self .norm_f .weight ,
659+ eps = self .norm_f .variance_epsilon ,
660+ trigger_completion_at_end = False ,
661+ ))
662+ else :
663+ hidden_states , _ = self .norm_f (hidden_states , residual )
571664 return hidden_states
572665
573666