diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 6e3f7cb9fb2..0c0176269b3 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -516,3 +516,18 @@ def forward(
         output = self.wo(output)
 
         return output, None
+
+
+@register_attention("skip")
+class AttentionSkip(Attention):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        return x, None
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 3a325d0f4f8..6587f7e1a10 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -15,6 +15,7 @@
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
+    AttentionSkip,
     ForwardOptions,
 )
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -95,7 +96,10 @@ def __init__(self, args: ModelArgs, attention: Attention):
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
 
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        if isinstance(self.attention, AttentionSkip):
+            self.attention_norm = nn.Identity()
+        else:
+            self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     @classmethod
@@ -120,8 +124,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
+        if not isinstance(self.attention, AttentionSkip):
+            h = x + h
 
-        h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
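
Usage note (not part of the patch): a minimal sketch of what the new "skip" variant does at runtime. It relies only on names visible in the diff (AttentionSkip and its forward signature); the tensor shape and the keyword-style call are illustrative, and the same class should also be reachable via ATTENTION_REGISTRY["skip"] thanks to the @register_attention decorator above.

    import torch

    from executorch.examples.models.llama.attention import AttentionSkip

    # AttentionSkip accepts and ignores any constructor arguments.
    attention = AttentionSkip()

    # Forward returns the input unchanged; freqs are unused, so None is fine here.
    x = torch.randn(1, 8, 64)  # (batch, seq_len, dim) -- illustrative shape
    y, _ = attention(x, freqs_cos=None, freqs_sin=None)
    assert torch.equal(x, y)   # the skip attention is a pure pass-through

Because the variant is a pure pass-through, the TransformerBlock change above also replaces the pre-attention RMSNorm with nn.Identity and drops the residual add, so a skipped layer reduces to h + feed_forward(ffn_norm(h)).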