
Commit 1e9bc91

fix
1 parent be5bbe5 commit 1e9bc91

File tree

1 file changed (+5, -4 lines)

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 5 additions & 4 deletions
@@ -107,20 +107,21 @@ def forward(
         )

         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
-        hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+        norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states
             ) * torch.tanh(enc_gate_msa).unsqueeze(1)
-            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (
+            norm_encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (
                 1 + enc_scale_mlp.unsqueeze(1)
             )

-        ff_output = self.ff(hidden_states)
+        ff_output = self.ff(norm_hidden_states)
         hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)

         if not self.context_pre_only:
-            context_ff_output = self.ff_context(encoder_hidden_states)
+            context_ff_output = self.ff_context(norm_encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0)

         return hidden_states, encoder_hidden_states
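
Before this change, hidden_states was overwritten by the norm3 output, so the gated residual at the end of the block was added to the normalized tensor rather than the original stream; the same variable reuse affected encoder_hidden_states with norm3_context. The fix keeps the normalized tensors in separate norm_hidden_states / norm_encoder_hidden_states variables so the feed-forward reads the modulated, normalized values while the residual additions use the un-normalized streams. Below is a minimal sketch of the corrected gated-residual feed-forward pattern; hidden_dim, norm3, ff, scale_mlp and gate_mlp are simplified stand-ins, not the actual Mochi block components.

    import torch
    import torch.nn as nn

    # Minimal stand-in for the modulated feed-forward step of a transformer block.
    # hidden_dim, norm3, ff, scale_mlp and gate_mlp are illustrative placeholders,
    # not the real Mochi module attributes.
    hidden_dim = 64
    norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
    ff = nn.Sequential(
        nn.Linear(hidden_dim, 4 * hidden_dim),
        nn.GELU(),
        nn.Linear(4 * hidden_dim, hidden_dim),
    )

    hidden_states = torch.randn(2, 16, hidden_dim)  # (batch, sequence, dim)
    scale_mlp = torch.randn(2, hidden_dim)          # per-sample modulation scale
    gate_mlp = torch.randn(2, hidden_dim)           # per-sample residual gate

    # Keep the normalized tensor in its own variable so the residual below
    # is added to the original hidden_states, not to the normalized copy.
    norm_hidden_states = norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
    ff_output = ff(norm_hidden_states)
    hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)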
