
Commit 6d6c3a0

sxu authored and facebook-github-bot committed
TransformerBlock: support attention skips
Summary: We want to support attention skips. This diff modifies `TransformerBlock` to make `attention_norm` and `attention` optional. Since our export script constructs the `TransformerBlock`s directly, this is enough for our use case. The top-level `Transformer` class still requires a single `attention_type`; making that interface support attention skips as well (which would require per-layer configuration) is out of scope for this diff.

Differential Revision: D84003431
1 parent a39866c commit 6d6c3a0

2 files changed (+22, -14 lines)

examples/models/llama/llama_transformer.py

Lines changed: 21 additions & 13 deletions
@@ -71,7 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(self, args: ModelArgs, attention: Optional[Attention]):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
@@ -95,7 +95,8 @@ def __init__(self, args: ModelArgs, attention: Attention):
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
 
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        if self.attention is not None:
+            self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     @classmethod
@@ -107,21 +108,28 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
             args (ModelArgs): model configuration parameters.
             rope (Rope): the rope object to use for rotary embeddings.
         """
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        attention = cls(args, layer_id, rope, **args.attention_kwargs)
+        if args.attention_type is not None:
+            if args.attention_type not in ATTENTION_REGISTRY:
+                raise ValueError(
+                    f"Unknown attention type: {args.attention_type}. "
+                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
+                )
+            cls = ATTENTION_REGISTRY[args.attention_type]
+            attention = cls(args, layer_id, rope, **args.attention_kwargs)
+        else:
+            attention = None
         return TransformerBlock(args, attention)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
-        h, attn_options_update = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
-        )
+        if self.attention is not None:
+            h, attn_options_update = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
+            )
+            h = x + h
+        else:
+            h = x
+            attn_options_update = attn_options
 
-        h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
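For intuition, the following is a minimal, self-contained sketch (not code from this diff) of the behavior the new forward path enables: when `attention` is None, the block skips `attention_norm` and the attention residual and reduces to the FFN sub-layer. The `TinyBlock` class, its dimensions, and the use of `nn.LayerNorm`/`nn.Linear` as stand-ins for the repo's `RMSNorm` and `FeedForward` are illustrative assumptions only.

from typing import Optional

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Simplified stand-in for TransformerBlock; LayerNorm/Linear replace the
    # repo's RMSNorm/FeedForward purely to keep the sketch self-contained.
    def __init__(self, dim: int, attention: Optional[nn.Module]):
        super().__init__()
        self.attention = attention
        if self.attention is not None:
            self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.attention is not None:
            h = x + self.attention(self.attention_norm(x))
        else:
            h = x  # attention skipped: input passes straight to the FFN sub-layer
        return h + self.feed_forward(self.ffn_norm(h))


x = torch.randn(1, 8, 16)
attn_block = TinyBlock(dim=16, attention=nn.Identity())  # placeholder attention module
skip_block = TinyBlock(dim=16, attention=None)           # attention-skip layer
print(attn_block(x).shape, skip_block(x).shape)  # both: torch.Size([1, 8, 16])

An export script can mix such blocks freely, passing a concrete attention module for some layers and None for others, which is the per-layer flexibility the commit summary refers to.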

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ class ModelArgs:
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
-    attention_type: str = "mha"  # Attention type, registered in attention.py
+    attention_type: Optional[str] = "mha"  # Attention type, registered in attention.py
     norm_type: str = "rmsnorm"  # Normalization type, registered in norm.py
     act_fn: ActFn = dataclasses.field(default=ActFn.SILU)  # Activation function type
     attention_qkv_bias: bool = False
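To illustrate what the `Optional[str]` field buys, here is a hedged, self-contained sketch of the factory pattern that `TransformerBlock.from_type` now follows: a None attention type bypasses the registry lookup and yields `attention=None`, i.e. an attention-skip block. The registry contents and the `build_attention` helper below are hypothetical stand-ins, not the repo's API.

from typing import Callable, Dict, Optional


# Hypothetical registry; the real one lives in attention.py as ATTENTION_REGISTRY.
ATTENTION_REGISTRY: Dict[str, Callable[[int], str]] = {
    "mha": lambda layer_id: f"MHA(layer_id={layer_id})",
}


def build_attention(attention_type: Optional[str], layer_id: int) -> Optional[str]:
    # Mirrors the None short-circuit added to TransformerBlock.from_type above.
    if attention_type is None:
        return None  # this layer skips attention entirely
    if attention_type not in ATTENTION_REGISTRY:
        raise ValueError(
            f"Unknown attention type: {attention_type}. "
            f"Available: {list(ATTENTION_REGISTRY.keys())}"
        )
    return ATTENTION_REGISTRY[attention_type](layer_id)


print(build_attention("mha", 0))  # MHA(layer_id=0)
print(build_attention(None, 1))   # None -> attention-skip block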
