import torch
from torch import nn

- from tensorrt_llm._utils import get_sm_version
+ from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .multi_stream_utils import maybe_execute_in_parallel
from .rms_norm import RMSNorm
- from .rotary_embedding import RotaryEmbedding
+ from .rotary_embedding import MRotaryEmbedding, RotaryEmbedding


def extract_extra_attrs(layer_idx: str, attn_type: str):
@@ -67,6 +67,16 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
    return metadata, attn_layer


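+ # Small helpers wrapped in torch.compile so the strided copy / concat used by the
+ # MLA paths below can be fused into fewer kernels (assumed motivation; they are
+ # functionally identical to Tensor.copy_ and torch.cat).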
+ @torch.compile
+ def compiled_copy_(dst, src):
+     dst.copy_(src)
+
+
+ @torch.compile
+ def compiled_cat(tensors, dim):
+     return torch.cat(tensors, dim)
+
+
@torch.library.custom_op("trtllm::attn_custom_op_inplace",
                         mutates_args=("output", ))
def attn_custom_op_inplace(
@@ -271,11 +281,19 @@ def __init__(

        self.rotary_emb = None
        if not self.rope_fusion and self.pos_embd_params is not None:
-             self.rotary_emb = RotaryEmbedding(
-                 self.pos_embd_params.rope,
-                 head_dim=self.head_dim,
-                 is_neox=self.pos_embd_params.is_neox,
-             )
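+             # Use the multimodal rotary embedding (mRoPE) variant when the RoPE type
+             # calls for it; mrope_section describes how the rotary dims are split
+             # across the position-id sections.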
+             if self.pos_embd_params.type.is_mrope():
+                 self.rotary_emb = MRotaryEmbedding(
+                     self.pos_embd_params.rope,
+                     head_dim=self.head_dim,
+                     is_neox=self.pos_embd_params.is_neox,
+                     mrope_section=self.pos_embd_params.mrope_section,
+                 )
+             else:
+                 self.rotary_emb = RotaryEmbedding(
+                     self.pos_embd_params.rope,
+                     head_dim=self.head_dim,
+                     is_neox=self.pos_embd_params.is_neox,
+                 )

        self.attn = create_attention(
            self.attn_backend,
@@ -301,6 +319,12 @@ def create_weights(self):
        # which could be modified after __init__
        self.attn.update_quant_config(self.quant_config)

+         self.o_proj.create_weights()
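+         # Cache whether o_proj carries a quantization scale so the hot paths
+         # (create_output / _attn_impl) can reuse the flag instead of re-checking.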
+         self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                 or self.o_proj.has_fp8_block_scales
+                                 or self.o_proj.has_fp8_rowwise
+                                 or self.o_proj.has_w4a8_nvfp4_fp8)
+
    def split_qkv(self, q, k=None, v=None):
        if k is None and v is None:
            q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -320,12 +344,8 @@ def create_output(self, q: torch.Tensor):
        out_dtype = q.dtype

        if self.attn_backend == "TRTLLM":
-             has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                                or self.o_proj.has_fp8_block_scales
-                                or self.o_proj.has_fp8_rowwise
-                                or self.o_proj.has_w4a8_nvfp4_fp8)
-             if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                     or self.attn.has_fp4_kv_cache):
+             if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                          or self.attn.has_fp4_kv_cache):
                out_dtype = torch.float8_e4m3fn
        output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
        return output
@@ -356,11 +376,7 @@ def _attn_impl(

        out_scale = None
        out_scale_sf = None
-         has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                            or self.o_proj.has_fp8_block_scales
-                            or self.o_proj.has_fp8_rowwise
-                            or self.o_proj.has_w4a8_nvfp4_fp8)
-         if has_quant_scale:
+         if self.has_quant_scale:
            out_scale = self.o_proj.inv_input_scale
            if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
                out_scale_sf = self.o_proj.input_scale
@@ -585,7 +601,7 @@ def fp8_block_scaling_bmm_out(
                                                   output)
        out.copy_(output)

-     elif sm_version == 100:
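+     # is_sm_100f() is assumed to match the SM100 family rather than only SM 100 exactly.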
+     elif is_sm_100f(sm_version):
        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
    else:
        raise NotImplementedError(f"SM{sm_version} is not supported")
@@ -858,6 +874,9 @@ def create_weights(self):
        self.mha.update_quant_config(self.quant_config)
        self.mqa.update_quant_config(self.quant_config)

+         # Although we use FP8 MLA for the context/generation phase, the output is still in BF16
+         self.out_scale = None
+
        # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
        # which can be modified after __init__
        has_fp8_block_scales = (
@@ -900,7 +919,7 @@ def create_weights(self):
            ),
            requires_grad=False,
        )
-         if get_sm_version() == 100:
+         if is_sm_100f():
            assert self.dtype == torch.bfloat16
            self.k_b_proj_trans_dequant = nn.Parameter(
                torch.empty(
@@ -1054,24 +1073,21 @@ def forward_context_default(
        )

        k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-         k[..., :self.qk_nope_head_dim] = k_nope.view(-1, self.num_heads,
-                                                      self.qk_nope_head_dim)
+         compiled_copy_(k[..., :self.qk_nope_head_dim],
+                        k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
        if self.apply_rotary_emb:
            k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                       self.qk_rope_head_dim)
        k = k.view(-1, self.num_heads * self.qk_head_dim)

-         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-         out_scale = None  # Currently we use BF16 MHA for context phase
-
        attn_output = self.mha.forward(
            q,
            k,
            v,
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=latent_cache,
-             out_scale=out_scale,
+             out_scale=self.out_scale,
            output=output,
        )

@@ -1116,7 +1132,7 @@ def forward_context_with_cached_kv(
        full_k_nope = full_k_nope.view(-1, self.num_heads,
                                       self.qk_nope_head_dim)
        full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-         full_k = torch.cat(
+         full_k = compiled_cat(
            (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
        full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)

@@ -1126,9 +1142,6 @@ def forward_context_with_cached_kv(
        full_kv = None
        full_k_nope = None

-         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-         out_scale = None  # Currently we use BF16 MHA for context phase
-
        # latent_cache must be None to differentiate from normal context phase,
        # so that we can skip applying RoPE and appending KV cache inside attention op
        attn_output = self.mha.forward(
@@ -1138,7 +1151,7 @@ def forward_context_with_cached_kv(
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=None,
-             out_scale=out_scale,
+             out_scale=self.out_scale,
            output=output,
        )

@@ -1214,7 +1227,7 @@ def forward_context_with_chunked_prefill(
            chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                 self.qk_nope_head_dim)
            chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-             chunked_k = torch.cat(
+             chunked_k = compiled_cat(
                (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                dim=-1)
            chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1232,7 +1245,6 @@ def forward_context_with_chunked_prefill(
                loop_idx]
            attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

-             out_scale = None
            # do not apply mask for attention within loop
            # latent_cache must be None to differentiate from normal context phase,
            # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1243,7 +1255,7 @@ def forward_context_with_chunked_prefill(
                attn_metadata,
                attention_input_type=AttentionInputType.context_only,
                latent_cache=None,
-                 out_scale=out_scale,
+                 out_scale=self.out_scale,
                attention_mask=PredefinedAttentionMask.FULL,
                softmax_stats_tensor=self.temp_softmax_stats_tensor,
                chunked_prefill_buffer_batch_size=attn_metadata.
@@ -1273,7 +1285,7 @@ def forward_context_with_chunked_prefill(

        k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-         k = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+         k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
        k = k.view(-1, self.num_heads * self.qk_head_dim)

        # copy q_lens to replace kv_lens_runtime
@@ -1284,9 +1296,6 @@ def forward_context_with_chunked_prefill(
            num_contexts].sum().item(
            )

-         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-         out_scale = None  # Currently we use BF16 MHA for context phase
-
        # latent_cache must be None to differentiate from normal context phase,
        # so that we can skip applying RoPE and appending KV cache inside attention op
        temp_attn_output = self.mha.forward(
@@ -1296,7 +1305,7 @@ def forward_context_with_chunked_prefill(
            attn_metadata,
            attention_input_type=AttentionInputType.context_only,
            latent_cache=None,
-             out_scale=out_scale,
+             out_scale=self.out_scale,
            softmax_stats_tensor=self.temp_softmax_stats_tensor,
            chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
            chunked_prefill_buffer_batch_size,
@@ -1394,16 +1403,13 @@ def forward_generation(
            self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
        ])

-         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-         out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
        attn_out_latent = self.mqa.forward(
            fused_q,
            None,
            None,
            attn_metadata,
            attention_input_type=AttentionInputType.generation_only,
-             out_scale=out_scale,
+             out_scale=self.out_scale,
            latent_cache=latent_cache,  # kvcache and k_pe
            q_pe=q_pe,  # used by `invokeMLARopeGeneration`
        )