 
 from tensorrt_llm.mapping import Mapping
 
-from ..attention_backend import (AttentionInputType, AttentionMetadata,
-                                 TrtllmAttention)
+from ..attention_backend import AttentionInputType, AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask)
 from ..attention_backend.utils import create_attention
@@ -104,19 +103,6 @@ def __init__(
         self.attn_backend = config.attn_backend
         self.pos_embd_params = pos_embd_params
 
-        self.enable_rope_fusion = self.attn_backend == "TRTLLM"
-        self.support_fused_qkv = self.attn_backend == "TRTLLM"
-        self.support_unfused_qkv = self.attn_backend != "TRTLLM"
-        self.rotary_emb = None
-        self.apply_rotary_emb = (not self.enable_rope_fusion
-                                 and pos_embd_params is not None)
-        if self.apply_rotary_emb:
-            self.rotary_emb = RotaryEmbedding(
-                pos_embd_params.rope,
-                head_dim=self.head_dim,
-                is_neox=pos_embd_params.is_neox,
-            )
-
         # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
         # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
         # handles them as a single fused operation.
@@ -132,8 +118,23 @@ def __init__(
 
         if not config.skip_create_weights:
             self.create_weights()
+        else:
+            self.create_backend()
 
-    def create_weights(self):
+        self.enable_rope_fusion = self.attn.support_fused_rope()
+        self.support_fused_qkv = self.attn.support_fused_qkv()
+
+        self.rotary_emb = None
+        self.apply_rotary_emb = (not self.enable_rope_fusion
+                                 and pos_embd_params is not None)
+        if self.apply_rotary_emb:
+            self.rotary_emb = RotaryEmbedding(
+                pos_embd_params.rope,
+                head_dim=self.head_dim,
+                is_neox=pos_embd_params.is_neox,
+            )
+
+    def create_backend(self):
         self.attn = create_attention(
             self.attn_backend,
             self.layer_idx,
@@ -144,10 +145,14 @@ def create_weights(self):
             quant_config=self.quant_config,
         )
 
+    def create_weights(self):
+        # recreate the backend when quant_config changes
+        self.create_backend()
+
     def convert_qkv(self, q, k, v):
         if k is None and v is None and not self.support_fused_qkv:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        elif k is not None and v is not None and not self.support_unfused_qkv:
+        elif k is not None and v is not None and self.support_fused_qkv:
             qkv = torch.concat([q, k, v], dim=-1)
             q, k, v = qkv, None, None
         return q, k, v
@@ -459,9 +464,8 @@ def yarn_get_mscale(scale=1, mscale=1):
         self.aux_stream = aux_stream
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
-        self.enable_rope_fusion = isinstance(self.mha, TrtllmAttention)
-        self.support_fused_qkv = isinstance(self.mha, TrtllmAttention)
-        self.support_unfused_qkv = not isinstance(self.mha, TrtllmAttention)
+        self.enable_rope_fusion = self.mha.support_fused_rope()
+        self.support_fused_qkv = self.mha.support_fused_qkv()
         self.rotary_emb = None
         self.apply_rotary_emb = not self.enable_rope_fusion
         if self.apply_rotary_emb:
@@ -575,7 +579,7 @@ def forward(
         return attn_output
 
     def _maybe_concat_qkv(self, q, k, v):
-        if k is not None and v is not None and not self.support_unfused_qkv:
+        if k is not None and v is not None and self.support_fused_qkv:
             qkv = torch.concat([q, k, v], dim=-1)
             q, k, v = qkv, None, None
         return q, k, v
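
The diff replaces hard-coded `attn_backend == "TRTLLM"` and `isinstance(self.mha, TrtllmAttention)` checks with capability queries on the backend object (`support_fused_rope()`, `support_fused_qkv()`), so callers branch on what a backend reports rather than on its concrete type. Below is a minimal sketch of that pattern; only the two `support_*` method names come from the diff, while the class names, defaults, and helper here are hypothetical stand-ins rather than the actual TensorRT-LLM classes.

```python
import torch


class AttentionBackendSketch:
    """Hypothetical stand-in for an attention backend base class."""

    def support_fused_rope(self) -> bool:
        # Default: RoPE is applied outside the attention kernel.
        return False

    def support_fused_qkv(self) -> bool:
        # Default: the kernel expects separate q, k, v tensors.
        return False


class FusedKernelBackendSketch(AttentionBackendSketch):
    """Stand-in for a backend whose kernel fuses RoPE and consumes packed QKV."""

    def support_fused_rope(self) -> bool:
        return True

    def support_fused_qkv(self) -> bool:
        return True


def maybe_concat_qkv(backend, q, k, v):
    # Mirrors the updated _maybe_concat_qkv/convert_qkv logic: pack q, k, v
    # into one tensor only when the backend advertises fused-QKV support.
    if k is not None and v is not None and backend.support_fused_qkv():
        return torch.concat([q, k, v], dim=-1), None, None
    return q, k, v
```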