Commit 4bd9e99

removed in-place sum, which may affect backward propagation logic

1 parent 7747007

1 file changed: 4 additions, 4 deletions

src/diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -138,8 +138,8 @@ def forward(
             **attention_kwargs,
         )
 
-        hidden_states += gate_msa * attn_hidden_states
-        encoder_hidden_states += enc_gate_msa * attn_encoder_hidden_states
+        hidden_states = hidden_states + gate_msa * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
@@ -150,8 +150,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)
 
-        hidden_states += gate_ff * ff_output[:, text_seq_length:]
-        encoder_hidden_states += enc_gate_ff * ff_output[:, :text_seq_length]
+        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
 
         return hidden_states, encoder_hidden_states
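
For context: `hidden_states += ...` mutates the tensor in place. PyTorch's autograd saves some intermediate tensors for the backward pass, and an in-place write to a saved tensor invalidates it, so `backward()` raises a RuntimeError. The snippet below is a minimal standalone sketch of that failure mode (hypothetical tensors, not the model's), contrasting it with the out-of-place form this commit switches to:

# Minimal sketch (hypothetical tensors, not the model's) of why an
# in-place sum can break backward propagation.
import torch

x = torch.randn(3, requires_grad=True)

# Out-of-place add: allocates a new tensor, so the sigmoid output that
# autograd saved for the backward pass stays intact.
y = torch.sigmoid(x)   # sigmoid's backward needs its own output
y = y + 1
y.sum().backward()     # works

# In-place add: overwrites the saved sigmoid output and bumps its
# version counter; autograd detects the mismatch at backward time.
x.grad = None
z = torch.sigmoid(x)
z += 1
z.sum().backward()     # RuntimeError: one of the variables needed for
                       # gradient computation has been modified by an
                       # inplace operation

Whether the `+=` in this file actually triggers the error depends on what autograd saved upstream of it; rewriting it as `hidden_states = hidden_states + ...` sidesteps the question at the cost of one extra allocation.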
