@@ -23,6 +23,7 @@ class ModelArgs(BaseModelArgs):
     vocab_size: int
     rotary_emb_base: int
     rotary_pct: float
+    use_parallel_residual: bool = True
     num_key_value_heads: int = None

     def __post_init__(self):
@@ -107,6 +108,7 @@ def __init__(self, args: ModelArgs):
         self.layer_norm_eps = args.layer_norm_eps
         self.attention = Attention(args)
         self.mlp = MLP(args)
+        self.use_parallel_residual = args.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(
             self.hidden_size,
             eps=self.layer_norm_eps,
@@ -121,12 +123,20 @@ def __call__(
         mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        residual = x
-        # NeoX runs attention and feedforward network in parallel.
-        attn = self.attention(self.input_layernorm(x), mask, cache)
-        ffn = self.mlp(self.post_attention_layernorm(x))
-        out = attn + ffn + residual
-        return out
+        if self.use_parallel_residual:
+            residual = x
+            # Run the attention and feedforward blocks in parallel on the same input.
+            attn = self.attention(self.input_layernorm(x), mask, cache)
+            ffn = self.mlp(self.post_attention_layernorm(x))
+            out = attn + ffn + residual
+            return out
+        else:
+            # Run the attention and feedforward blocks sequentially.
+            attn_output = self.attention(self.input_layernorm(x), mask, cache)
+            x = x + attn_output
+            ffn_output = self.mlp(self.post_attention_layernorm(x))
+            x = x + ffn_output
+            return x


 class GPTNeoXModel(nn.Module):
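For reference, the two residual layouts toggled by `use_parallel_residual` differ only in where the second residual branch reads its input. Below is a minimal sketch of that difference, with stand-in callables in place of the real attention, MLP, and LayerNorm modules (the names `block`, `attn_fn`, `mlp_fn`, `ln1`, and `ln2` are illustrative only, not part of this model):

```python
import mlx.core as mx

def block(x, attn_fn, mlp_fn, ln1, ln2, use_parallel_residual=True):
    # Illustrative stand-in for the transformer block in the diff above,
    # not the actual model code.
    if use_parallel_residual:
        # Parallel (GPT-NeoX default): both sub-layers read the same input x.
        return x + attn_fn(ln1(x)) + mlp_fn(ln2(x))
    # Sequential: the MLP sees the attention output folded back into x first.
    x = x + attn_fn(ln1(x))
    return x + mlp_fn(ln2(x))

# Toy call with identity sub-layers, just to show the two code paths.
x = mx.ones((1, 4, 8))
y_parallel = block(x, lambda t: t, lambda t: t, lambda t: t, lambda t: t, True)
y_sequential = block(x, lambda t: t, lambda t: t, lambda t: t, lambda t: t, False)
```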