Commit dc32bac

[#4745][fix] Pass lora_params through Qwen2/3 model forward (#10174)
Signed-off-by: Kanghwan Jang <[email protected]>
1 parent: cbf8357

2 files changed (+10, -4 lines)

tensorrt_llm/_torch/models/modeling_qwen.py

Lines changed: 5 additions & 3 deletions
@@ -171,7 +171,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)

         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
@@ -230,7 +230,8 @@ def forward(
             attn_metadata=attn_metadata,
             residual=residual,
             mrope_config=mrope_config,
-            spec_metadata=spec_metadata)
+            spec_metadata=spec_metadata,
+            **kwargs)

         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -265,7 +266,8 @@ def forward(
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             mrope_config=mrope_config,
-            spec_metadata=spec_metadata)
+            spec_metadata=spec_metadata,
+            **kwargs)

         return self.logits_processor.forward(
             output,
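
The change is mechanical but easy to miss: every forward() in the call chain has to pass **kwargs along, or per-request extras such as lora_params silently vanish before reaching the layers that consume them. A minimal, runnable sketch of the pattern, using hypothetical simplified classes rather than the actual TensorRT-LLM modules:

    class Mlp:
        def __call__(self, hidden_states, lora_params=None, **kwargs):
            # A LoRA-aware submodule only sees lora_params if every caller
            # above it forwarded **kwargs.
            return hidden_states + (" +lora" if lora_params is not None else "")

    class DecoderLayer:
        def __init__(self):
            self.mlp = Mlp()

        def forward(self, hidden_states, spec_metadata=None, **kwargs):
            # The fix in this commit: call self.mlp(hidden_states, **kwargs)
            # instead of self.mlp(hidden_states).
            return self.mlp(hidden_states, **kwargs)

    class Model:
        def __init__(self):
            self.layers = [DecoderLayer() for _ in range(2)]

        def forward(self, hidden_states, spec_metadata=None, **kwargs):
            # Each hop threads **kwargs through unchanged, so lora_params
            # rides along from the top-level call to every layer.
            for layer in self.layers:
                hidden_states = layer.forward(
                    hidden_states, spec_metadata=spec_metadata, **kwargs)
            return hidden_states

    print(Model().forward("h", lora_params={"rank": 8}))  # h +lora +lora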

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 5 additions & 1 deletion
@@ -152,6 +152,7 @@ def forward(
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_allreduce),
             cutlass_min_latency_mode=False,
+            **kwargs,
         )
         if deepstack_embeds is not None and self.layer_idx in range(
                 len(deepstack_embeds)):
@@ -221,7 +222,10 @@ def forward(
             residual=residual,
             spec_metadata=spec_metadata,
             mrope_config=mrope_config,
-            deepstack_embeds=deepstack_embeds)
+            deepstack_embeds=deepstack_embeds,
+            **kwargs,
+        )
+
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
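In modeling_qwen3.py the same threading is applied to the attention call and the decoder-layer loop. The failure mode being fixed is easy to reproduce in isolation; a sketch with hypothetical functions, not the real TensorRT-LLM signatures:

    def mlp(hidden_states, lora_params=None):
        # Stands in for a LoRA-aware submodule: it behaves differently
        # depending on whether lora_params actually arrives.
        return ("lora" if lora_params is not None else "base", hidden_states)

    def layer_forward_before(hidden_states, **kwargs):
        return mlp(hidden_states)            # kwargs dropped: LoRA never applied

    def layer_forward_after(hidden_states, **kwargs):
        return mlp(hidden_states, **kwargs)  # kwargs forwarded: LoRA applied

    print(layer_forward_before("h", lora_params={"rank": 8}))  # ('base', 'h')
    print(layer_forward_after("h", lora_params={"rank": 8}))   # ('lora', 'h')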
