
Commit 2ce09f8

Removed in-place sums; may affect backward propagation logic.

Parent: 4bd9e99

File tree

1 file changed: +4 −4 lines


src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 4 additions & 4 deletions
@@ -106,8 +106,8 @@ def forward(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
-        hidden_states += gate_msa.unsqueeze(1) * attn_hidden_states
-        encoder_hidden_states += c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
 
         # norm & modulate
         norm_hidden_states = self.norm2(hidden_states)
@@ -120,8 +120,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)
 
-        hidden_states += gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
-        encoder_hidden_states += c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
 
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
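
The motivation for this change: PyTorch autograd errors out when a tensor that an earlier operation saved for its backward pass is later modified in place, which is exactly what `hidden_states += ...` risks inside a differentiable forward pass. Below is a minimal sketch (illustrative only, not code from this repository) that reproduces the failure with `torch.exp`, whose backward uses its saved output, and shows the out-of-place form the commit switches to:

# Minimal sketch of the autograd hazard (illustrative names/values).
import torch

x = torch.randn(3, requires_grad=True)

# torch.exp saves its *output* for the backward pass; mutating that
# output in place invalidates the saved tensor.
h = torch.exp(x)
try:
    h += 1.0              # in-place edit of a tensor autograd saved
    h.sum().backward()    # RuntimeError: "... modified by an inplace operation"
except RuntimeError as err:
    print("in-place version failed:", err)

# Out-of-place addition allocates a new tensor and leaves the saved
# output intact, so backward succeeds -- the form the commit adopts.
h = torch.exp(x)
h = h + 1.0
h.sum().backward()
print(x.grad)             # d/dx (exp(x) + 1) = exp(x)

The tradeoff is one extra allocation per addition; the in-place form saves memory but is only safe when no autograd node still needs the overwritten tensor.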
