
Commit 5346f53

feat: Introduce feature properties for attention backend. (NVIDIA#3659)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 61ee983 commit 5346f53

File tree

tensorrt_llm/_torch/attention_backend/interface.py
tensorrt_llm/_torch/attention_backend/trtllm.py
tensorrt_llm/_torch/attention_backend/utils.py
tensorrt_llm/_torch/modules/attention.py

4 files changed: +51 -22 lines changed


tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 12 additions & 0 deletions
@@ -537,6 +537,18 @@ def forward(self,
         """
         raise NotImplementedError
 
+    @classmethod
+    def support_fused_rope(cls) -> bool:
+        return False
+
+    @classmethod
+    def support_fused_qkv(cls) -> bool:
+        return False
+
+    @classmethod
+    def support_mla(cls) -> bool:
+        return False
+
 
 @dataclass(kw_only=True, unsafe_hash=True)
 class MLAParams:
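
In the base interface, every feature property defaults to unsupported, so a concrete backend overrides only the capabilities it actually implements and inherits False for the rest. A minimal, self-contained sketch of that pattern; the class names below are illustrative stand-ins, not the real TensorRT-LLM classes:

```python
# Illustrative stand-ins only: "AttentionBackendSketch" and "FlashLikeBackend"
# are not real TensorRT-LLM classes. A subclass overrides just the feature
# properties it supports and inherits the False defaults for everything else.


class AttentionBackendSketch:
    @classmethod
    def support_fused_rope(cls) -> bool:
        return False

    @classmethod
    def support_fused_qkv(cls) -> bool:
        return False

    @classmethod
    def support_mla(cls) -> bool:
        return False


class FlashLikeBackend(AttentionBackendSketch):
    # Accepts fused QKV input, but has no fused RoPE or MLA support.
    @classmethod
    def support_fused_qkv(cls) -> bool:
        return True


assert FlashLikeBackend.support_fused_qkv() is True
assert FlashLikeBackend.support_fused_rope() is False  # inherited default
assert FlashLikeBackend.support_mla() is False  # inherited default
```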

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 12 additions & 0 deletions
@@ -710,3 +710,15 @@ def forward(
             or k is not None,
             attention_mask=attention_mask)
         return output
+
+    @classmethod
+    def support_fused_rope(cls) -> bool:
+        return True
+
+    @classmethod
+    def support_fused_qkv(cls) -> bool:
+        return True
+
+    @classmethod
+    def support_mla(cls) -> bool:
+        return True
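
Because the properties are classmethods, callers can query a backend's capabilities on the class itself, before any instance exists. A short usage sketch; the import path is an assumption based on the import removed from tensorrt_llm/_torch/modules/attention.py in this commit:

```python
# Usage sketch: query feature properties at the class level, no instance needed.
# The import path is assumed, not confirmed by this diff.
from tensorrt_llm._torch.attention_backend import TrtllmAttention

assert TrtllmAttention.support_fused_rope()
assert TrtllmAttention.support_fused_qkv()
assert TrtllmAttention.support_mla()
```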

tensorrt_llm/_torch/attention_backend/utils.py

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ def create_attention(
     attn_cls = get_attention_backend(backend_name)
 
     if is_mla_enable:
-        assert attn_cls == TrtllmAttention
+        assert attn_cls.support_mla(
+        ), f"MLA is not supported for {backend_name} backend"
         assert (q_lora_rank > 0 and kv_lora_rank > 0 and qk_rope_head_dim > 0
                 and qk_nope_head_dim > 0 and v_head_dim > 0)
         mla_params = MLAParams(
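
The MLA guard now asks the selected backend class whether it supports MLA instead of comparing it against TrtllmAttention, so a future MLA-capable backend passes without touching this code and an unsupported one fails with a named error. A small, hypothetical sketch of the same guard; `_NoMlaBackend` and `check_mla_support` are made-up stand-ins, not the real create_attention:

```python
# Hypothetical stand-ins that mirror the new guard in create_attention.
class _NoMlaBackend:
    @classmethod
    def support_mla(cls) -> bool:
        return False


def check_mla_support(attn_cls, backend_name: str) -> None:
    # Reject MLA only when the backend class reports that it cannot handle it.
    assert attn_cls.support_mla(), (
        f"MLA is not supported for {backend_name} backend")


try:
    check_mla_support(_NoMlaBackend, "EXAMPLE")
except AssertionError as err:
    print(err)  # MLA is not supported for EXAMPLE backend
```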

tensorrt_llm/_torch/modules/attention.py

Lines changed: 25 additions & 21 deletions
@@ -6,8 +6,7 @@
 
 from tensorrt_llm.mapping import Mapping
 
-from ..attention_backend import (AttentionInputType, AttentionMetadata,
-                                 TrtllmAttention)
+from ..attention_backend import AttentionInputType, AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask)
 from ..attention_backend.utils import create_attention
@@ -104,19 +103,6 @@ def __init__(
         self.attn_backend = config.attn_backend
         self.pos_embd_params = pos_embd_params
 
-        self.enable_rope_fusion = self.attn_backend == "TRTLLM"
-        self.support_fused_qkv = self.attn_backend == "TRTLLM"
-        self.support_unfused_qkv = self.attn_backend != "TRTLLM"
-        self.rotary_emb = None
-        self.apply_rotary_emb = (not self.enable_rope_fusion
-                                 and pos_embd_params is not None)
-        if self.apply_rotary_emb:
-            self.rotary_emb = RotaryEmbedding(
-                pos_embd_params.rope,
-                head_dim=self.head_dim,
-                is_neox=pos_embd_params.is_neox,
-            )
-
         # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
         # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
         # handles them as a single fused operation.
@@ -132,8 +118,23 @@ def __init__(
 
         if not config.skip_create_weights:
             self.create_weights()
+        else:
+            self.create_backend()
 
-    def create_weights(self):
+        self.enable_rope_fusion = self.attn.support_fused_rope()
+        self.support_fused_qkv = self.attn.support_fused_qkv()
+
+        self.rotary_emb = None
+        self.apply_rotary_emb = (not self.enable_rope_fusion
+                                 and pos_embd_params is not None)
+        if self.apply_rotary_emb:
+            self.rotary_emb = RotaryEmbedding(
+                pos_embd_params.rope,
+                head_dim=self.head_dim,
+                is_neox=pos_embd_params.is_neox,
+            )
+
+    def create_backend(self):
         self.attn = create_attention(
             self.attn_backend,
             self.layer_idx,
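
With the backend instantiated up front by create_backend, __init__ can derive the RoPE and QKV handling from the backend's feature properties instead of matching the backend name string. A condensed, self-contained sketch of that decision; `_UnfusedBackend` and `plan_attention_features` are illustrative names, and the real code builds a RotaryEmbedding module rather than returning flags:

```python
# Condensed sketch of the capability-driven setup; names here are illustrative.
class _UnfusedBackend:
    @classmethod
    def support_fused_rope(cls) -> bool:
        return False

    @classmethod
    def support_fused_qkv(cls) -> bool:
        return False


def plan_attention_features(attn, has_pos_embd_params: bool) -> dict:
    enable_rope_fusion = attn.support_fused_rope()
    return {
        "enable_rope_fusion": enable_rope_fusion,
        # RoPE is applied inside the module only when the backend cannot fuse it.
        "apply_rotary_emb": (not enable_rope_fusion) and has_pos_embd_params,
        "support_fused_qkv": attn.support_fused_qkv(),
    }


print(plan_attention_features(_UnfusedBackend, has_pos_embd_params=True))
# {'enable_rope_fusion': False, 'apply_rotary_emb': True, 'support_fused_qkv': False}
```
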
@@ -144,10 +145,14 @@ def create_weights(self):
             quant_config=self.quant_config,
         )
 
+    def create_weights(self):
+        # recreate the backend when quant_config changes
+        self.create_backend()
+
     def convert_qkv(self, q, k, v):
         if k is None and v is None and not self.support_fused_qkv:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        elif k is not None and v is not None and not self.support_unfused_qkv:
+        elif k is not None and v is not None and self.support_fused_qkv:
             qkv = torch.concat([q, k, v], dim=-1)
             q, k, v = qkv, None, None
         return q, k, v
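
convert_qkv now keys off a single support_fused_qkv property: a fused buffer is split apart for backends that want separate Q/K/V, and separate tensors are concatenated for backends that accept a fused buffer. A standalone sketch with plain torch tensors; the free function and the example sizes below are illustrative, not the real module method:

```python
# Standalone sketch of the fused/unfused QKV conversion; sizes are arbitrary.
import torch


def convert_qkv(q, k, v, *, support_fused_qkv: bool, q_size: int, kv_size: int):
    if k is None and v is None and not support_fused_qkv:
        # Fused input, but the backend wants separate Q/K/V: split it apart.
        q, k, v = q.split([q_size, kv_size, kv_size], dim=-1)
    elif k is not None and v is not None and support_fused_qkv:
        # Separate inputs, but the backend wants one fused buffer: concatenate.
        q, k, v = torch.concat([q, k, v], dim=-1), None, None
    return q, k, v


fused = torch.randn(2, 8 + 4 + 4)  # [num_tokens, q_size + 2 * kv_size]
q, k, v = convert_qkv(fused, None, None,
                      support_fused_qkv=False, q_size=8, kv_size=4)
assert q.shape[-1] == 8 and k.shape[-1] == 4 and v.shape[-1] == 4
```
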
@@ -459,9 +464,8 @@ def yarn_get_mscale(scale=1, mscale=1):
         self.aux_stream = aux_stream
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
 
-        self.enable_rope_fusion = isinstance(self.mha, TrtllmAttention)
-        self.support_fused_qkv = isinstance(self.mha, TrtllmAttention)
-        self.support_unfused_qkv = not isinstance(self.mha, TrtllmAttention)
+        self.enable_rope_fusion = self.mha.support_fused_rope()
+        self.support_fused_qkv = self.mha.support_fused_qkv()
         self.rotary_emb = None
         self.apply_rotary_emb = not self.enable_rope_fusion
         if self.apply_rotary_emb:
@@ -575,7 +579,7 @@ def forward(
         return attn_output
 
     def _maybe_concat_qkv(self, q, k, v):
-        if k is not None and v is not None and not self.support_unfused_qkv:
+        if k is not None and v is not None and self.support_fused_qkv:
             qkv = torch.concat([q, k, v], dim=-1)
             q, k, v = qkv, None, None
         return q, k, v
