Commit 985b79c

[TRTLLM-8348][feat] Speed up concat k and copy k_nope in context phase using torch.compile (#8044)
Signed-off-by: Tailing Yuan <[email protected]>
1 parent 1e2e851 commit 985b79c
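The commit wraps two small, memory-bound steps of the MLA context phase (copying k_nope into a slice of the preallocated k buffer, and concatenating k_nope with k_pe broadcast across heads) in torch.compile helpers, so they can be lowered to fused, shape-specialized kernels instead of going through the generic eager ops. A rough way to check the claimed speedup in isolation is a micro-benchmark of the eager versus compiled concat; the sketch below is not part of the commit, and the tensor sizes are illustrative placeholders:

import torch
from torch.utils import benchmark

@torch.compile
def compiled_cat(tensors, dim):
    # Same helper pattern as the one added in attention.py below.
    return torch.cat(tensors, dim)

device = "cuda" if torch.cuda.is_available() else "cpu"
num_heads = 128  # placeholder head count
k_nope = torch.randn(4096, num_heads, 128, device=device)
k_pe = torch.randn(4096, 1, 64, device=device)

# Warm up so the one-time compilation cost is excluded from the timing.
compiled_cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)

for label, fn in [("eager torch.cat", torch.cat), ("compiled_cat", compiled_cat)]:
    timer = benchmark.Timer(
        stmt="fn((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)",
        globals={"fn": fn, "k_nope": k_nope, "k_pe": k_pe, "num_heads": num_heads},
    )
    print(label, timer.timeit(200))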

File tree: 1 file changed, +15 -5 lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 15 additions & 5 deletions
@@ -67,6 +67,16 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
+@torch.compile
+def compiled_copy_(dst, src):
+    dst.copy_(src)
+
+
+@torch.compile
+def compiled_cat(tensors, dim):
+    return torch.cat(tensors, dim)
+
+
 @torch.library.custom_op("trtllm::attn_custom_op_inplace",
                          mutates_args=("output", ))
 def attn_custom_op_inplace(
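The two helpers above are plain torch.compile wrappers, so they can be exercised standalone. A minimal usage sketch with placeholder sizes (the real call sites pass self.num_heads, self.qk_nope_head_dim and self.qk_rope_head_dim):

import torch

# Assumes the helpers added above are importable from this module.
from tensorrt_llm._torch.modules.attention import compiled_cat, compiled_copy_

num_tokens, num_heads, d_nope, d_rope = 8, 4, 128, 64

k = torch.empty(num_tokens, num_heads, d_nope + d_rope)
k_nope = torch.randn(num_tokens, num_heads, d_nope)
k_pe = torch.randn(num_tokens, 1, d_rope)

# In-place write of k_nope into the leading slice of k.
compiled_copy_(k[..., :d_nope], k_nope)

# Concat of k_nope with k_pe broadcast across the head dimension.
full_k = compiled_cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)
print(full_k.shape)  # torch.Size([8, 4, 192])

The first call with a new input shape triggers compilation; subsequent calls reuse the cached compiled code.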
@@ -1063,8 +1073,8 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads,
-                                                     self.qk_nope_head_dim)
+        compiled_copy_(k[..., :self.qk_nope_head_dim],
+                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1122,7 +1132,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = torch.cat(
+        full_k = compiled_cat(
             (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1217,7 +1227,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = torch.cat(
+            chunked_k = compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1275,7 +1285,7 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
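Since the rewrites above are intended as a pure performance change, a quick sanity check (not part of this commit) is to compare the compiled helpers against the eager expressions they replace, again with placeholder sizes:

import torch

# Assumes the helpers are importable from the module changed in this commit.
from tensorrt_llm._torch.modules.attention import compiled_cat, compiled_copy_

num_tokens, num_heads, d_nope, d_rope = 8, 4, 128, 64

k_nope = torch.randn(num_tokens, num_heads, d_nope)
k_pe = torch.randn(num_tokens, 1, d_rope)

# Original eager expressions ...
k_eager = torch.empty(num_tokens, num_heads, d_nope + d_rope)
k_eager[..., :d_nope] = k_nope
cat_eager = torch.cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)

# ... versus the compiled helpers introduced in this commit.
k_compiled = torch.empty(num_tokens, num_heads, d_nope + d_rope)
compiled_copy_(k_compiled[..., :d_nope], k_nope)
cat_compiled = compiled_cat((k_nope, k_pe.expand(-1, num_heads, -1)), dim=-1)

# Compare only the slice that was written; the rope tail of k is still uninitialized here.
torch.testing.assert_close(k_eager[..., :d_nope], k_compiled[..., :d_nope])
torch.testing.assert_close(cat_eager, cat_compiled)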
