
Commit 8226ef2

Revert "[None][feat] support attention dp for qwen3 dense model" (#7765)
1 parent e7c1569 commit 8226ef2

1 file changed: +1 −16 lines changed


tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 1 addition & 16 deletions
@@ -8,7 +8,6 @@

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -83,8 +82,6 @@ def __init__(
             model_config,
             layer_idx=layer_idx,
         )
-        self.mapping = model_config.mapping
-        self.enable_attention_dp = self.mapping.enable_attention_dp

         # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
         # and https://nvbugspro.nvidia.com/bug/5505402)
@@ -95,7 +92,6 @@ def __init__(
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
-            overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
             disable_deep_gemm=disable_deep_gemm,
         )
@@ -106,8 +102,6 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        self.disable_allreduce = (self.mapping.tp_size == 1
-                                  or self.enable_attention_dp)

     def forward(
         self,
@@ -132,22 +126,13 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             mrope_config=mrope_config,
-            all_reduce_params=AllReduceParams(
-                enable_allreduce=not self.disable_allreduce),
             **kwargs,
         )

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(
-            hidden_states,
-            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(
-                enable_allreduce=not self.disable_allreduce),
-            cutlass_min_latency_mode=False,
-        )
+        hidden_states = self.mlp(hidden_states)

         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
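
For context, the reverted change wired attention data parallelism into the dense Qwen3 decoder layer: when `enable_attention_dp` is set on the mapping, the MLP was built with `overridden_tp_size=1` and the per-layer allreduce was skipped (it is also trivially skipped when `tp_size == 1`). The snippet below is a minimal, standalone sketch of just that gating decision, not the actual TensorRT-LLM classes; `Mapping` and the helper name here are simplified stand-ins introduced for illustration.

# Standalone sketch of the gating logic the revert removes. Only the
# decision logic mirrors the diff; the class and function names are
# hypothetical, not TensorRT-LLM APIs.
from dataclasses import dataclass


@dataclass
class Mapping:
    tp_size: int = 1
    enable_attention_dp: bool = False


def mlp_sharding_choice(mapping: Mapping) -> tuple:
    """Return (overridden_tp_size, enable_allreduce) as the reverted code chose them.

    With attention DP enabled, the MLP keeps tensor parallelism off
    (tp_size forced to 1) and the layer's allreduce is disabled; the
    allreduce is likewise disabled when tp_size is already 1.
    """
    overridden_tp_size = 1 if mapping.enable_attention_dp else None
    disable_allreduce = mapping.tp_size == 1 or mapping.enable_attention_dp
    return overridden_tp_size, not disable_allreduce


if __name__ == "__main__":
    # TP=4 with attention DP: MLP forced to tp_size=1, allreduce skipped.
    print(mlp_sharding_choice(Mapping(tp_size=4, enable_attention_dp=True)))
    # TP=4 without attention DP: default sharding, allreduce enabled.
    print(mlp_sharding_choice(Mapping(tp_size=4, enable_attention_dp=False)))

After the revert, none of this branching remains in the dense model path: the MLP is always called as plain `self.mlp(hidden_states)` and the default allreduce behavior applies.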

0 commit comments
