
Commit 2aa31f9

add parallel_residual setting to gptneox (#586)
Co-authored-by: Alexander Schwirjow <[email protected]>
1 parent 4decc4d commit 2aa31f9

File tree: 1 file changed (+16, −6 lines)

mlx_lm/models/gpt_neox.py

Lines changed: 16 additions & 6 deletions
@@ -23,6 +23,7 @@ class ModelArgs(BaseModelArgs):
     vocab_size: int
     rotary_emb_base: int
     rotary_pct: float
+    use_parallel_residual: bool = True
     num_key_value_heads: int = None
 
     def __post_init__(self):
@@ -107,6 +108,7 @@ def __init__(self, args: ModelArgs):
         self.layer_norm_eps = args.layer_norm_eps
         self.attention = Attention(args)
         self.mlp = MLP(args)
+        self.use_parallel_residual = args.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(
             self.hidden_size,
             eps=self.layer_norm_eps,
@@ -121,12 +123,20 @@ def __call__(
         mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        residual = x
-        # NeoX runs attention and feedforward network in parallel.
-        attn = self.attention(self.input_layernorm(x), mask, cache)
-        ffn = self.mlp(self.post_attention_layernorm(x))
-        out = attn + ffn + residual
-        return out
+        if self.use_parallel_residual:
+            residual = x
+            # Run attention and feedforward network in parallel.
+            attn = self.attention(self.input_layernorm(x), mask, cache)
+            ffn = self.mlp(self.post_attention_layernorm(x))
+            out = attn + ffn + residual
+            return out
+        else:
+            # Run attention and feedforward network sequentially.
+            attn_output = self.attention(self.input_layernorm(x), mask, cache)
+            x = x + attn_output
+            ffn_output = self.mlp(self.post_attention_layernorm(x))
+            x = x + ffn_output
+            return x
 
 
 class GPTNeoXModel(nn.Module):
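
For context, the new flag selects between the two residual layouts shown in the diff: the parallel form adds the attention and MLP branches to the same block input, while the sequential form feeds the output of the attention residual step into the MLP branch. Below is a minimal plain-Python sketch of the two wirings, not the mlx-lm implementation; `block`, `attention`, `mlp`, `ln1`, and `ln2` are hypothetical stand-ins for the real submodules.

import numpy as np

def block(x, attention, mlp, ln1, ln2, use_parallel_residual=True):
    # Sketch of a GPT-NeoX-style transformer block with a residual toggle.
    if use_parallel_residual:
        # Parallel residual: both branches read the same input x,
        # and their outputs are summed with the residual in one step.
        return x + attention(ln1(x)) + mlp(ln2(x))
    else:
        # Sequential residual: the MLP branch sees the result of the
        # attention residual step (standard pre-norm ordering).
        x = x + attention(ln1(x))
        x = x + mlp(ln2(x))
        return x

# Toy usage with dummy stand-ins for attention, MLP, and layer norms.
x = np.ones((1, 4, 8))
branch = lambda t: 0.5 * t     # dummy attention / MLP
ident = lambda t: t            # dummy layer norm
y_parallel = block(x, branch, branch, ident, ident, use_parallel_residual=True)
y_sequential = block(x, branch, branch, ident, ident, use_parallel_residual=False)

With `use_parallel_residual=True` (the GPT-NeoX default, preserved by the new `ModelArgs` default), the block keeps its previous behavior; setting it to False reproduces the sequential ordering expected by checkpoints configured that way.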
