Commit bd740c9

[None][fix] Avoid unnecessary concat in attn_output_gate case. (#8094)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent 6c4cc4c commit bd740c9

File tree

2 files changed: +4 −4 lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 3 deletions

@@ -538,10 +538,9 @@ def forward(
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None

-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
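The change above removes the round trip in the gated branch: previously q, k, and v were concatenated back into one tensor only for downstream code to split them again, while the non-gated branch now explicitly yields `qkv, None, None`. A minimal stand-in for that control flow, using plain Python lists in place of torch tensors (the helper names `chunk2` and `prepare_qkv` are illustrative, not TensorRT-LLM APIs):

```python
def chunk2(xs):
    """Split a sequence in half, like torch.chunk(x, 2, dim=-1) on the last dim."""
    mid = len(xs) // 2
    return xs[:mid], xs[mid:]

def prepare_qkv(qkv, k, v, attn_output_gate):
    """Mirror the fixed control flow: no re-concat in the gated branch."""
    if attn_output_gate:
        # q and the output gate arrive packed in one tensor; peel q off.
        q, _gate = chunk2(qkv)
        # The old code concatenated q, k, v back into one tensor here,
        # only for downstream code to split them again; they now stay separate.
        return q, k, v
    else:
        # Non-gated path: qkv is already a single fused tensor.
        return qkv, None, None

print(prepare_qkv([1, 2, 3, 4], [5], [6], attn_output_gate=True))
print(prepare_qkv([1, 2, 3, 4], None, None, attn_output_gate=False))
```

With separate pieces flowing out of both branches, the decision of whether a fused tensor is needed moves into `apply_rope`, which is what the second file in this commit changes.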
tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 2 additions & 1 deletion

@@ -249,6 +249,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         else:
             return q, k, v

-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
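This hunk replaces the old assertion (which required callers to pass an already-fused qkv tensor) with concat-on-demand: if q, k, v arrive as separate tensors, they are fused exactly once here. A minimal sketch of that tail logic, again with plain Python lists standing in for tensors (`concat_last_dim` and `fuse_for_rope` are illustrative helpers, not the actual TensorRT-LLM functions):

```python
def concat_last_dim(parts):
    """Stand-in for torch.concat(parts, dim=-1) on list 'tensors'."""
    out = []
    for p in parts:
        out.extend(p)
    return out

def fuse_for_rope(q, k, v):
    """Mirror the fixed apply_rope tail: concat only when q, k, v are separate."""
    qkv = q
    if k is not None and v is not None:
        # Gated path: three separate pieces, fused exactly once here.
        qkv = concat_last_dim([q, k, v])
    # Non-gated path: q is already the fused qkv tensor; no work needed.
    return qkv  # the actual qk-norm/rope application is elided
```

Either way, downstream code sees one fused tensor, so callers no longer have to concatenate (and later re-split) just to satisfy the old assertion.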
