Skip to content

Commit 2392022

Browse files
authored
[#4585][feat] Replace unified attention before export (#8303)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 32e1ad6 commit 2392022

File tree

4 files changed

+84
-126
lines changed

4 files changed

+84
-126
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Patch for torch.export.export to detect and replace hf attention_interface with unified attention."""
2+
3+
from typing import Optional
4+
5+
import torch
6+
import torch.export as te
7+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
8+
9+
from ..interface import BaseExportPatch, ExportPatchRegistry
10+
11+
# Maps kwarg names of the HF attention_interface signature onto the
# corresponding kwarg names of auto_deploy::torch_attention.
# NOTE: both "scaling" and "scale" alias to "scale", and both "s_aux" and
# "sinks" alias to "sinks"; insertion order is preserved deliberately.
HF_ATTN_KWARGS_MAPPING = dict(
    dropout="dropout_p",
    is_causal="is_causal",
    scaling="scale",
    scale="scale",
    s_aux="sinks",
    sinks="sinks",
    sliding_window="sliding_window",
    logit_cap="logit_cap",
)
22+
23+
24+
def torch_attention_hf_wrapper(
    self: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
):
    """Wrapper of auto_deploy::torch_attention with HF attention_interface signature.

    Accepts q/k/v in the HF layout [batch, num_heads, seq_len, head_dim],
    forwards them to torch.ops.auto_deploy.torch_attention in "bsnd" layout,
    and returns an (output, attn_weights) pair as HF expects; attention
    weights are not computed, so the second element is always None.
    """
    # The custom op consumes [batch, seq_len, num_heads, head_dim] ("bsnd"),
    # so swap the head and sequence axes of each input.
    q_bsnd = query.transpose(1, 2)
    k_bsnd = key.transpose(1, 2)
    v_bsnd = value.transpose(1, 2)

    # Translate only the recognized HF kwargs to op kwargs; everything else
    # (e.g. HF bookkeeping kwargs) is intentionally dropped. Iterating kwargs
    # preserves caller order when aliases (scaling/scale, s_aux/sinks) collide.
    op_kwargs = {
        HF_ATTN_KWARGS_MAPPING[name]: val
        for name, val in kwargs.items()
        if name in HF_ATTN_KWARGS_MAPPING
    }

    out = torch.ops.auto_deploy.torch_attention(
        q_bsnd,
        k_bsnd,
        v_bsnd,
        attn_mask=attention_mask,
        layout="bsnd",
        **op_kwargs,
    )

    return out, None
53+
54+
55+
@ExportPatchRegistry.register("unified_attn")
class UnifiedAttnPatch(BaseExportPatch):
    """
    Patch on torch.export.export to replace attention_interface with torch.ops.auto_deploy.torch_attention.

    While applied, torch.export.export is wrapped so that HF models are traced
    with the "ad_unified_attn" attention implementation (registered in
    ALL_ATTENTION_FUNCTIONS), which routes attention through the unified
    auto_deploy::torch_attention custom op.
    """

    def _apply_patch(self):
        """Apply the te.export patch."""
        # Store original torch.export.export so _revert_patch can restore it.
        self.original_values["te.export"] = te.export

        # Register the wrapper so HF can resolve "ad_unified_attn" by name.
        ALL_ATTENTION_FUNCTIONS.register("ad_unified_attn", torch_attention_hf_wrapper)

        def _export_with_unified_attn(model, *args, **kwargs):
            # torch_export_to_gm is called at both export stage and attn matching stage
            # we only patch attn implementation for export stage
            if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
                # Fix: swap the attention implementation only for the duration
                # of this export call and restore the caller's setting after,
                # so the patched value does not leak past the export stage.
                prev_impl = model.config._attn_implementation
                model.config._attn_implementation = "ad_unified_attn"
                try:
                    return self.original_values["te.export"](model, *args, **kwargs)
                finally:
                    model.config._attn_implementation = prev_impl
            return self.original_values["te.export"](model, *args, **kwargs)

        # Apply patch
        te.export = _export_with_unified_attn

    def _revert_patch(self):
        """Revert the te.export patch."""
        te.export = self.original_values["te.export"]

tensorrt_llm/_torch/auto_deploy/models/patches/decilm.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
1313

1414
from ...custom_ops.attention_interface import AttentionDescriptor, Constant
15+
from ...export.library.unified_attn import HF_ATTN_KWARGS_MAPPING
1516
from ...models.factory import ModelFactory
1617
from ...shim.interface import CachedSequenceInterface
1718
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -39,16 +40,8 @@ def fake_profiler_mha(
3940

4041
# construct kwargs for bsnd_grouped_sdpa
4142
node_kwargs = {"attn_mask": attention_mask, "is_causal": is_causal}
42-
kwargs_to_op = {
43-
"dropout": "dropout_p",
44-
"scaling": "scale",
45-
"scale": "scale",
46-
"s_aux": "sinks",
47-
"sinks": "sinks",
48-
"sliding_window": "sliding_window",
49-
"logit_cap": "logit_cap",
50-
}
51-
for k_kwargs, k_op_kwargs in kwargs_to_op.items():
43+
44+
for k_kwargs, k_op_kwargs in HF_ATTN_KWARGS_MAPPING.items():
5245
if k_kwargs in kwargs:
5346
node_kwargs[k_op_kwargs] = kwargs[k_kwargs]
5447

0 commit comments

Comments
 (0)