 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
         layer_idx: int,
+        reduce_output: bool = False,
     ):
         config = model_config.pretrained_config
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@ def __init__(
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )
 
     def forward(
@@ -154,6 +157,7 @@ def __init__(
         shared_expert_intermediate_size = (
             config.moe_shared_expert_intermediate_size *
             config.n_shared_experts)
+
         self.shared_experts = MLP(
             hidden_size=config.hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -193,11 +197,14 @@ def __init__(
             activation_type=self.activation_type,
         )
 
-        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
-        self.allreduce = AllReduce(
-            mapping=model_config.mapping,
-            strategy=model_config.allreduce_strategy,
-        )
+        if not model_config.mapping.enable_attention_dp:
+            # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+            self.allreduce = AllReduce(
+                mapping=model_config.mapping,
+                strategy=model_config.allreduce_strategy,
+            )
+        else:
+            self.allreduce = None
 
         # Setup latent projection layers.
         # These layers should NOT be TP-sharded to ensure MoE receives
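This hunk makes the expert-combining all-reduce conditional: with attention DP every rank processes its own tokens end to end, so the shared/routed partial results never need to cross ranks. A minimal sketch of how the forward path could consume the new attribute, assuming hypothetical shared_output and routed_output tensors holding the per-rank partial results (illustration only, not part of the diff):

def _combine_expert_outputs(self, shared_output, routed_output):
    # Each TP rank holds a partial sum of the shared and routed expert outputs.
    out = shared_output + routed_output
    if self.allreduce is not None:
        # Tensor-parallel case: sum the partial results across ranks.
        out = self.allreduce(out)
    # Attention-DP case: self.allreduce is None and each rank keeps its own
    # tokens' results, so no cross-rank reduction is performed.
    return out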
@@ -322,7 +329,11 @@ def __init__(
         elif layer_type == "-":
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
-            self.mixer = TransformerLayer(model_config, layer_idx)
+            self.mixer = TransformerLayer(
+                model_config,
+                layer_idx,
+                reduce_output=not model_config.mapping.enable_attention_dp
+                and model_config.mapping.tp_size > 1)
         elif layer_type == "E":
             self.mixer = NemotronHMOE(model_config,
                                       layer_idx=layer_idx,
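The reduce_output argument is enabled only when tensor parallelism is active and attention DP is off. A small helper expressing that predicate (hypothetical, not part of the diff) makes the intent explicit:

def _needs_tp_reduce(mapping) -> bool:
    # With attention DP each rank serves its own requests, so outputs must stay
    # rank-local; with tp_size == 1 there is nothing to reduce anyway.
    return (not mapping.enable_attention_dp) and mapping.tp_size > 1

TransformerLayer would then receive reduce_output=_needs_tp_reduce(model_config.mapping), matching the expression passed above.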
@@ -365,12 +376,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             aux_stream_list[2],
         }
 
-        # calculate embeddings
-        self.embed_tokens = Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=config.torch_dtype,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem sizes of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+                mapping=model_config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )
 
         # create layers
         layers = []
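The embedding now has two configurations: a plain replicated table under attention DP (each rank embeds only its own tokens, so no collective is possible or needed), and a tensor-parallel table otherwise, with TensorParallelMode.COLUMN and gather_output=True so every rank ends up with full hidden states. For intuition, a generic Megatron-style vocab-parallel lookup is sketched below; this illustrates the idea rather than the Embedding module's internals, and assumes the table is sharded along the vocabulary dimension with torch.distributed already initialized:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def vocab_parallel_embedding(input_ids, local_weight, vocab_start, vocab_end):
    # local_weight holds only this rank's [vocab_shard, hidden] slice.
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = input_ids - vocab_start
    local_ids[mask] = 0                        # out-of-range tokens hit row 0 ...
    out = F.embedding(local_ids, local_weight)
    out[mask] = 0.0                            # ... and their rows are zeroed
    # Summing the per-rank partial results reconstructs the full embedding
    # for every token on every rank.
    dist.all_reduce(out)
    return out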