@@ -67,6 +67,16 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
6767 return metadata , attn_layer
6868
6969
70+ @torch .compile
71+ def compiled_copy_ (dst , src ):
72+ dst .copy_ (src )
73+
74+
75+ @torch .compile
76+ def compiled_cat (tensors , dim ):
77+ return torch .cat (tensors , dim )
78+
79+
7080@torch .library .custom_op ("trtllm::attn_custom_op_inplace" ,
7181 mutates_args = ("output" , ))
7282def attn_custom_op_inplace (
@@ -1063,8 +1073,8 @@ def forward_context_default(
10631073 )
10641074
10651075 k = torch .empty_like (q ).view (- 1 , self .num_heads , self .qk_head_dim )
1066- k [..., :self .qk_nope_head_dim ] = k_nope . view ( - 1 , self . num_heads ,
1067- self .qk_nope_head_dim )
1076+ compiled_copy_ ( k [..., :self .qk_nope_head_dim ],
1077+ k_nope . view ( - 1 , self . num_heads , self .qk_nope_head_dim ) )
10681078 if self .apply_rotary_emb :
10691079 k [..., self .qk_nope_head_dim :] = k_pe .view (- 1 , 1 ,
10701080 self .qk_rope_head_dim )
@@ -1122,7 +1132,7 @@ def forward_context_with_cached_kv(
11221132 full_k_nope = full_k_nope .view (- 1 , self .num_heads ,
11231133 self .qk_nope_head_dim )
11241134 full_k_pe = full_k_pe .view (- 1 , 1 , self .qk_rope_head_dim )
1125- full_k = torch . cat (
1135+ full_k = compiled_cat (
11261136 (full_k_nope , full_k_pe .expand (- 1 , self .num_heads , - 1 )), dim = - 1 )
11271137 full_k = full_k .view (- 1 , self .num_heads * self .qk_head_dim )
11281138
@@ -1217,7 +1227,7 @@ def forward_context_with_chunked_prefill(
12171227 chunked_k_nope = chunked_k_nope .view (- 1 , self .num_heads ,
12181228 self .qk_nope_head_dim )
12191229 chunked_k_pe = chunked_k_pe .view (- 1 , 1 , self .qk_rope_head_dim )
1220- chunked_k = torch . cat (
1230+ chunked_k = compiled_cat (
12211231 (chunked_k_nope , chunked_k_pe .expand (- 1 , self .num_heads , - 1 )),
12221232 dim = - 1 )
12231233 chunked_k = chunked_k .view (- 1 , self .num_heads * self .qk_head_dim )
@@ -1275,7 +1285,7 @@ def forward_context_with_chunked_prefill(
12751285
12761286 k_nope = k_nope .view (- 1 , self .num_heads , self .qk_nope_head_dim )
12771287 k_pe = k_pe .view (- 1 , 1 , self .qk_rope_head_dim )
1278- k = torch . cat ((k_nope , k_pe .expand (- 1 , self .num_heads , - 1 )), dim = - 1 )
1288+ k = compiled_cat ((k_nope , k_pe .expand (- 1 , self .num_heads , - 1 )), dim = - 1 )
12791289 k = k .view (- 1 , self .num_heads * self .qk_head_dim )
12801290
12811291 # copy q_lens to replace kv_lens_runtime
0 commit comments