|
| 1 | +"""Patch for torch.export.export to detect and replace hf attention_interface with unified attention.""" |
| 2 | + |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.export as te |
| 7 | +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| 8 | + |
| 9 | +from ..interface import BaseExportPatch, ExportPatchRegistry |
| 10 | + |
# Kwargs mapping for HF attention_interface to auto_deploy::torch_attention.
# Keys are kwarg names HF may pass into an attention_interface callable;
# values are the corresponding parameter names of torch.ops.auto_deploy.torch_attention.
# Unlisted kwargs are intentionally dropped by the wrapper.
HF_ATTN_KWARGS_MAPPING: dict[str, str] = {
    "dropout": "dropout_p",
    "is_causal": "is_causal",
    # Both spellings seen across HF model code map to the op's "scale".
    "scaling": "scale",
    "scale": "scale",
    # Both spellings seen across HF model code map to the op's "sinks".
    "s_aux": "sinks",
    "sinks": "sinks",
    "sliding_window": "sliding_window",
    "logit_cap": "logit_cap",
}
| 22 | + |
| 23 | + |
def torch_attention_hf_wrapper(
    self: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
):
    """Adapter exposing auto_deploy::torch_attention via the HF attention_interface signature.

    ``self`` is required by the HF attention_interface calling convention but is unused.
    Returns ``(attn_output, None)`` — attention weights are not computed.
    """
    # HF hands tensors in [batch, num_heads, seq_len, head_dim]; the custom op is
    # invoked with layout="bsnd", i.e. [batch, seq_len, num_heads, head_dim].
    q_bsnd = query.transpose(1, 2)
    k_bsnd = key.transpose(1, 2)
    v_bsnd = value.transpose(1, 2)

    # Translate recognized HF kwargs to the op's parameter names; anything not in
    # the mapping is deliberately ignored.
    translated_kwargs = {}
    for name, val in kwargs.items():
        if name in HF_ATTN_KWARGS_MAPPING:
            translated_kwargs[HF_ATTN_KWARGS_MAPPING[name]] = val

    attn_out = torch.ops.auto_deploy.torch_attention(
        q_bsnd,
        k_bsnd,
        v_bsnd,
        attn_mask=attention_mask,
        layout="bsnd",
        **translated_kwargs,
    )

    return attn_out, None
| 53 | + |
| 54 | + |
@ExportPatchRegistry.register("unified_attn")
class UnifiedAttnPatch(BaseExportPatch):
    """Patch torch.export.export to route the HF attention_interface through
    torch.ops.auto_deploy.torch_attention for the duration of an export call.
    """

    def _apply_patch(self):
        """Apply the te.export patch."""
        # Store original torch.export.export so _revert_patch can restore it.
        self.original_values["te.export"] = te.export

        # Register the wrapper function. Registration is idempotent; it is
        # intentionally left in place by _revert_patch so any model still
        # configured with "ad_unified_attn" keeps resolving.
        ALL_ATTENTION_FUNCTIONS.register("ad_unified_attn", torch_attention_hf_wrapper)

        def _export_with_unified_attn(model, *args, **kwargs):
            # torch_export_to_gm is called at both export stage and attn matching
            # stage; we only patch the attn implementation for the export stage,
            # so restore the model's original setting once export returns
            # (previously the override leaked past the export call).
            config = getattr(model, "config", None)
            if config is None or not hasattr(config, "_attn_implementation"):
                return self.original_values["te.export"](model, *args, **kwargs)
            prev_impl = config._attn_implementation
            config._attn_implementation = "ad_unified_attn"
            try:
                return self.original_values["te.export"](model, *args, **kwargs)
            finally:
                config._attn_implementation = prev_impl

        # Apply patch
        te.export = _export_with_unified_attn

    def _revert_patch(self):
        """Revert the te.export patch."""
        te.export = self.original_values["te.export"]
0 commit comments