
Commit a0c5ab2

in-place sums
1 parent 249a8a2 commit a0c5ab2

2 files changed: +8 −8 lines changed


src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 4 additions & 4 deletions
@@ -138,8 +138,8 @@ def forward(
             **attention_kwargs,
         )
 
-        hidden_states = hidden_states + gate_msa * attn_hidden_states
-        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+        hidden_states.add_(gate_msa * attn_hidden_states)
+        encoder_hidden_states.add_(enc_gate_msa * attn_encoder_hidden_states)
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
@@ -150,8 +150,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)
 
-        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
-        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+        hidden_states.add_(gate_ff * ff_output[:, text_seq_length:])
+        encoder_hidden_states.add_(enc_gate_ff * ff_output[:, :text_seq_length])
 
         return hidden_states, encoder_hidden_states
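For context, a minimal sketch (not part of the commit; tensor names and shapes are illustrative) of what the rewrite does: `Tensor.add_` writes the sum into the tensor's existing storage, whereas `x + y` allocates a new tensor, so the in-place form avoids one activation-sized temporary per gated residual sum.

    import torch

    hidden_states = torch.randn(2, 16, 64)
    gate_msa = torch.rand(2, 16, 64)
    attn_hidden_states = torch.randn(2, 16, 64)

    # Out-of-place: allocates a fresh tensor and rebinds the name.
    reference = hidden_states + gate_msa * attn_hidden_states

    # In-place: the result lands in hidden_states' existing storage.
    storage_before = hidden_states.data_ptr()
    hidden_states.add_(gate_msa * attn_hidden_states)

    assert hidden_states.data_ptr() == storage_before
    assert torch.allclose(hidden_states, reference)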

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 4 additions & 4 deletions
@@ -106,8 +106,8 @@ def forward(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
-        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
-        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+        hidden_states.add_(gate_msa.unsqueeze(1) * attn_hidden_states)
+        encoder_hidden_states.add_(c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states)
 
         # norm & modulate
         norm_hidden_states = self.norm2(hidden_states)
@@ -120,8 +120,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)
 
-        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
-        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+        hidden_states.add_(gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:])
+        encoder_hidden_states.add_(c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length])
 
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
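A general PyTorch caveat (not raised in this commit, just worth keeping in mind when switching to in-place ops): `add_` is only autograd-safe when the tensor being overwritten was not saved for the backward pass of an earlier op; otherwise backward fails with a version-counter error. A small self-contained illustration:

    import torch

    x = torch.randn(4, requires_grad=True)
    y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
    y.add_(1.0)            # in-place update bumps y's version counter
    try:
        y.sum().backward()
    except RuntimeError as e:
        # "one of the variables needed for gradient computation has been
        # modified by an inplace operation"
        print(e)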
