
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -83,8 +82,6 @@ def __init__(
             model_config,
             layer_idx=layer_idx,
         )
-        self.mapping = model_config.mapping
-        self.enable_attention_dp = self.mapping.enable_attention_dp

         # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
         # and https://nvbugspro.nvidia.com/bug/5505402)
@@ -95,7 +92,6 @@ def __init__(
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
-            overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
             disable_deep_gemm=disable_deep_gemm,
         )
@@ -106,8 +102,6 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        self.disable_allreduce = (self.mapping.tp_size == 1
-                                  or self.enable_attention_dp)

     def forward(
         self,
@@ -132,22 +126,13 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             mrope_config=mrope_config,
-            all_reduce_params=AllReduceParams(
-                enable_allreduce=not self.disable_allreduce),
             **kwargs,
         )

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(
-            hidden_states,
-            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(
-                enable_allreduce=not self.disable_allreduce),
-            cutlass_min_latency_mode=False,
-        )
+        hidden_states = self.mlp(hidden_states)

         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
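
Both normalization calls in the hunk above thread a `residual` tensor alongside `hidden_states`: each RMSNorm adds the incoming residual before normalizing and returns the pre-norm sum for reuse as the next residual. A minimal sketch of that pattern in plain PyTorch (illustrative only, with assumed names; not the TensorRT-LLM `RMSNorm` implementation):

```python
from typing import Optional, Tuple

import torch


def fused_add_rmsnorm(hidden_states: torch.Tensor,
                      residual: Optional[torch.Tensor],
                      weight: torch.Tensor,
                      eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
    # Add the running residual first (skipped when residual is None,
    # i.e. on the very first norm of the layer stack).
    if residual is not None:
        hidden_states = hidden_states + residual
    # The pre-norm sum is handed back so the caller can reuse it as
    # the residual for the next sublayer.
    residual = hidden_states
    # Standard RMSNorm in fp32 for numerical stability.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states.float() * torch.rsqrt(variance + eps)
    return (weight * hidden_states).to(weight.dtype), residual
```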