diff --git a/tensorrt_llm/_torch/models/modeling_qwen.py b/tensorrt_llm/_torch/models/modeling_qwen.py
index fa20c4df002..df6d83e5b75 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen.py
@@ -171,7 +171,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
 
         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
@@ -230,7 +230,8 @@ def forward(
             attn_metadata=attn_metadata,
             residual=residual,
             mrope_config=mrope_config,
-            spec_metadata=spec_metadata)
+            spec_metadata=spec_metadata,
+            **kwargs)
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -265,7 +266,8 @@ def forward(
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             mrope_config=mrope_config,
-            spec_metadata=spec_metadata)
+            spec_metadata=spec_metadata,
+            **kwargs)
 
         return self.logits_processor.forward(
             output,
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py
index 49eacd068ea..67470a8eb1b 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -152,6 +152,7 @@ def forward(
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_allreduce),
             cutlass_min_latency_mode=False,
+            **kwargs,
         )
         if deepstack_embeds is not None and self.layer_idx in range(
                 len(deepstack_embeds)):
@@ -221,7 +222,10 @@ def forward(
             residual=residual,
             spec_metadata=spec_metadata,
             mrope_config=mrope_config,
-            deepstack_embeds=deepstack_embeds)
+            deepstack_embeds=deepstack_embeds,
+            **kwargs,
+        )
+
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 