@@ -88,11 +88,17 @@ def __init__(
         self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
         self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
-        self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True)
+        self.ff = FeedForward(
+            dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
+        )
         self.ff_context = None
         if not context_pre_only:
             self.ff_context = FeedForward(
-                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
+                pooled_projection_dim,
+                inner_dim=self.ff_context_inner_dim,
+                activation_fn=activation_fn,
+                bias=False,
+                flip_gate=True,
             )
 
         self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
@@ -131,7 +137,9 @@ def forward(
             ) * torch.tanh(enc_gate_msa).unsqueeze(1)
             norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(1)
+            encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
+                enc_gate_mlp
+            ).unsqueeze(1)
 
         return hidden_states, encoder_hidden_states
 
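Note: the second hunk only rewraps the tanh-gated feed-forward update in the context branch; its behavior is unchanged. A minimal, self-contained sketch of that update pattern is below, using plain torch.nn stand-ins. The shapes, the 4x hidden size, and nn.RMSNorm (available in PyTorch >= 2.4) are illustrative assumptions, not the repo's RMSNorm/FeedForward implementations (flip_gate and the surrounding block are omitted).

# Sketch of the tanh-gated MLP update from the second hunk (assumed shapes/modules).
import torch
import torch.nn as nn

batch, seq_len, dim = 2, 16, 64

# elementwise_affine=False matches the diff: normalization without learned scale.
norm3_context = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=False)
norm4_context = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=False)
# Stand-in for the repo's FeedForward module.
ff_context = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

encoder_hidden_states = torch.randn(batch, seq_len, dim)
enc_scale_mlp = torch.randn(batch, dim)  # per-sample modulation scale
enc_gate_mlp = torch.randn(batch, dim)   # per-sample gate, squashed by tanh

# Scale the normalized stream, run the MLP, then add the gated, re-normalized output back.
norm_states = norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
ff_out = ff_context(norm_states)
encoder_hidden_states = encoder_hidden_states + norm4_context(ff_out) * torch.tanh(enc_gate_mlp).unsqueeze(1)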