From 976b081a1fc18a39c45376e451a98200ae79d0d0 Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Wed, 8 Oct 2025 07:44:48 -0700
Subject: [PATCH] TransformerBlock: support attention skips (#14826)

Summary:
We want to support attention skips. This diff modifies `TransformerBlock` to
make `attention_norm` and `attention` optional. Since our export script
constructs the `TransformerBlock`s directly, this is enough for our use case.

The top-level `Transformer` class still requires a single `attention_type`;
making that interface also support attention skips (which would require a
per-layer configuration) is outside the scope of this diff.

Reviewed By: billmguo

Differential Revision: D84003431
---
 examples/models/llama/attention.py         | 15 +++++++++++++++
 examples/models/llama/llama_transformer.py |  9 +++++++--
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 6e3f7cb9fb2..0c0176269b3 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -516,3 +516,18 @@ def forward(
         output = self.wo(output)
 
         return output, None
+
+
+@register_attention("skip")
+class AttentionSkip(Attention):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        return x, None
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 3a325d0f4f8..6587f7e1a10 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -15,6 +15,7 @@
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
+    AttentionSkip,
     ForwardOptions,
 )
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -95,7 +96,10 @@ def __init__(self, args: ModelArgs, attention: Attention):
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
 
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        if isinstance(self.attention, AttentionSkip):
+            self.attention_norm = nn.Identity()
+        else:
+            self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     @classmethod
@@ -120,8 +124,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
+        if not isinstance(self.attention, AttentionSkip):
+            h = x + h
 
-        h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
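
Usage sketch (not part of the patch): a minimal, hypothetical example of how an
export script might wire an attention-skip layer into a block using the
interfaces touched above. The `ModelArgs` import path and its field values are
assumptions for illustration only.

    # Hypothetical sketch; argument values and the ModelArgs module path are assumed.
    from executorch.examples.models.llama.attention import AttentionSkip
    from executorch.examples.models.llama.llama_transformer import TransformerBlock
    from executorch.examples.models.llama.model_args import ModelArgs  # assumed path

    args = ModelArgs(dim=64, n_heads=4, hidden_dim=256)  # illustrative values only

    # AttentionSkip.__init__ discards its arguments, so no layer_id/rope is needed.
    block = TransformerBlock(args, AttentionSkip())

    # For this block, attention_norm is nn.Identity() and the attention returns x
    # unchanged, so forward() reduces to ffn_norm + feed_forward with a residual.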