Commit 2672dd3

TransformerBlock: support attention skips
Differential Revision: D84003431
Pull Request resolved: #14826
1 parent 1da530d

2 files changed: +22 −2 lines

examples/models/llama/attention.py (15 additions, 0 deletions)

@@ -516,3 +516,18 @@ def forward(
         output = self.wo(output)

         return output, None
+
+
+@register_attention("skip")
+class AttentionSkip(Attention):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        return x, None
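For context, the new AttentionSkip class leans on the attention registry in attention.py: registering under the key "skip" lets a configuration select a no-op attention whose forward returns its input unchanged and reports no cache update. Below is a minimal, self-contained sketch of that pattern; ATTENTION_REGISTRY, register_attention, and the class body are simplified local stand-ins for the real definitions in examples/models/llama/attention.py, so read it as an illustration rather than the module itself.

```python
# Standalone sketch (not the ExecuTorch code): a string-keyed registry plus a
# no-op "skip" attention. All names here are redefined locally so the snippet
# runs on its own.
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn

ATTENTION_REGISTRY: Dict[str, Type[nn.Module]] = {}


def register_attention(name: str):
    """Decorator that records an attention class under a string key."""

    def decorator(cls: Type[nn.Module]) -> Type[nn.Module]:
        ATTENTION_REGISTRY[name] = cls
        return cls

    return decorator


@register_attention("skip")
class AttentionSkip(nn.Module):
    """No-op attention: forwards its input unchanged, returns no cache update."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        return x, None


if __name__ == "__main__":
    attn = ATTENTION_REGISTRY["skip"]()
    x = torch.randn(2, 8, 16)      # (batch, seq_len, dim)
    freqs = torch.zeros(8, 8)      # rotary tables are unused on the skip path
    out, cache = attn(x, freqs, freqs)
    assert torch.equal(out, x) and cache is None
```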

examples/models/llama/llama_transformer.py (7 additions, 2 deletions)

@@ -15,6 +15,7 @@
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
+    AttentionSkip,
     ForwardOptions,
 )
 from executorch.examples.models.llama.feed_forward import FeedForward
@@ -95,7 +96,10 @@ def __init__(self, args: ModelArgs, attention: Attention):
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)

-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        if isinstance(self.attention, AttentionSkip):
+            self.attention_norm = nn.Identity()
+        else:
+            self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

     @classmethod
@@ -120,8 +124,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
+        if not isinstance(self.attention, AttentionSkip):
+            h = x + h

-        h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
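Taken together, the two TransformerBlock changes mean that when a block's attention is an AttentionSkip, the pre-attention RMSNorm becomes nn.Identity and the attention residual add is bypassed, so the block collapses to its feed-forward sublayer. The sketch below reproduces that control flow with simplified local stand-ins (a no-op SkipAttention, LayerNorm in place of RMSNorm, a plain MLP in place of FeedForward); it illustrates the skip path under those assumptions and is not the ExecuTorch implementation.

```python
# Standalone sketch of the block-level control flow added by this commit.
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn


class SkipAttention(nn.Module):
    """Local stand-in for AttentionSkip: returns its input and no cache update."""

    def forward(
        self, x: torch.Tensor, *args: Any, **kwargs: Any
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        return x, None


class MiniBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, attention: nn.Module) -> None:
        super().__init__()
        self.attention = attention
        # Mirrors the commit: no pre-attention norm when attention is skipped.
        # (LayerNorm here is only a stand-in for the example's RMSNorm.)
        self.attention_norm = (
            nn.Identity() if isinstance(attention, SkipAttention) else nn.LayerNorm(dim)
        )
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, _ = self.attention(self.attention_norm(x))
        # Mirrors the commit: the attention residual is only added for a real
        # attention, so a skipped block reduces to x + FFN(norm(x)).
        if not isinstance(self.attention, SkipAttention):
            h = x + h
        return h + self.feed_forward(self.ffn_norm(h))


if __name__ == "__main__":
    x = torch.randn(2, 8, 16)
    block = MiniBlock(dim=16, hidden_dim=32, attention=SkipAttention())
    out = block(x)  # only the feed-forward sublayer transforms x
    print(out.shape)  # torch.Size([2, 8, 16])
```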
