File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
src/diffusers/models/transformers Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -107,20 +107,21 @@ def forward(
107107 )
108108
109109 hidden_states = hidden_states + self .norm2 (attn_hidden_states ) * torch .tanh (gate_msa ).unsqueeze (1 )
110- hidden_states = self .norm3 (hidden_states ) * (1 + scale_mlp .unsqueeze (1 ))
110+ norm_hidden_states = self .norm3 (hidden_states ) * (1 + scale_mlp .unsqueeze (1 ))
111+
111112 if not self .context_pre_only :
112113 encoder_hidden_states = encoder_hidden_states + self .norm2_context (
113114 context_attn_hidden_states
114115 ) * torch .tanh (enc_gate_msa ).unsqueeze (1 )
115- encoder_hidden_states = encoder_hidden_states + self .norm3_context (encoder_hidden_states ) * (
116+ norm_encoder_hidden_states = encoder_hidden_states + self .norm3_context (encoder_hidden_states ) * (
116117 1 + enc_scale_mlp .unsqueeze (1 )
117118 )
118119
119- ff_output = self .ff (hidden_states )
120+ ff_output = self .ff (norm_hidden_states )
120121 hidden_states = hidden_states + ff_output * torch .tanh (gate_mlp ).unsqueeze (1 )
121122
122123 if not self .context_pre_only :
123- context_ff_output = self .ff_context (encoder_hidden_states )
124+ context_ff_output = self .ff_context (norm_encoder_hidden_states )
124125 encoder_hidden_states = encoder_hidden_states + context_ff_output * torch .tanh (enc_gate_mlp ).unsqueeze (0 )
125126
126127 return hidden_states , encoder_hidden_states
You can’t perform that action at this time.
0 commit comments