
Commit 99767f9

sxu authored and facebook-github-bot committed
TransformerBlock: support attention skips (#14826)
Summary: We want to support attention skips. This diff modifies `TransformerBlock` to make `attention_norm` and `attention` optional. Since our export script constructs the `TransformerBlock`s directly, this is enough for our use case. The top-level `Transformer` class still requires a single `attention_type`; making that interface support attention skips as well (which needs a different configuration for each layer) is outside the scope of this diff.

Reviewed By: billmguo

Differential Revision: D84003431
1 parent 0e74a17 commit 99767f9
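
As a rough illustration of the construction path the summary describes, the sketch below shows how an export flow that builds `TransformerBlock`s itself could hand selected layers an `AttentionSkip` while the remaining layers keep a regular attention. The helper names (`build_blocks`, `skip_layers`, `make_attention`) are illustrative only, and the no-argument `AttentionSkip()` constructor is an assumption not shown in this diff.

```python
# Hypothetical sketch, not part of this commit: build blocks directly, skipping
# attention on selected layers. AttentionSkip() with no arguments is an assumption.
from executorch.examples.models.llama.attention import AttentionSkip
from executorch.examples.models.llama.llama_transformer import TransformerBlock


def build_blocks(args, rope, skip_layers, make_attention):
    """args: ModelArgs; make_attention: caller-supplied factory for real attention layers."""
    blocks = []
    for layer_id in range(args.n_layers):
        if layer_id in skip_layers:
            attention = AttentionSkip()  # no-op attention for skipped layers (assumed no-arg init)
        else:
            attention = make_attention(args, layer_id, rope)
        blocks.append(TransformerBlock(args, attention))
    return blocks
```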

2 files changed: 17 additions & 2 deletions

examples/models/llama/attention.py

Lines changed: 12 additions & 0 deletions
@@ -516,3 +516,15 @@ def forward(
         output = self.wo(output)

         return output, None
+
+
+@register_attention("skip")
+class AttentionSkip(Attention):
+    def forward(
+        self,
+        x: torch.Tensor,
+        _freqs_cos: torch.Tensor,
+        _freqs_sin: torch.Tensor,
+        **_kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        return x, None
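
For context, the new `"skip"` attention is a pure pass-through: `forward` returns the hidden states untouched and no cache update. The check below is illustrative, not part of the commit; it assumes `ATTENTION_REGISTRY` is keyed by the string passed to `register_attention` and that `AttentionSkip` needs no constructor arguments.

```python
# Illustrative behavior check (not from the commit). Assumes a no-arg constructor
# and that the registry key matches the string passed to register_attention.
import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY

skip_attention = ATTENTION_REGISTRY["skip"]()  # -> AttentionSkip instance
x = torch.randn(1, 16, 256)                    # (batch, seq_len, dim)
freqs = torch.empty(0)                         # rotary tables are ignored by the skip
out, cache_update = skip_attention(x, freqs, freqs)
assert torch.equal(out, x) and cache_update is None
```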

examples/models/llama/llama_transformer.py

Lines changed: 5 additions & 2 deletions
@@ -15,6 +15,7 @@
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
+    AttentionSkip,
     ForwardOptions,
 )
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -95,7 +96,8 @@ def __init__(self, args: ModelArgs, attention: Attention):
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)

-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        if isinstance(self.attention, AttentionSkip):
+            self.attention_norm = nn.Identity()
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

     @classmethod
@@ -120,8 +122,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
+        if not isinstance(self.attention, AttentionSkip):
+            h = x + h

-        h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else: