Commit 7747007

simpler in-place expressions
1 parent a0c5ab2 commit 7747007
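
The commit rewrites a handful of in-place tensor updates from the method form `tensor.add_(x)` to the operator form `tensor += x`. For context, a minimal standalone sketch (not part of this commit) showing that `+=` on a `torch.Tensor` dispatches to `Tensor.__iadd__` and therefore still writes into the existing storage, just like `.add_()`:

    import torch

    # Hypothetical check, for illustration only.
    x = torch.randn(2, 3)
    ptr_before = x.data_ptr()

    x += torch.ones(2, 3)               # in-place, same effect as x.add_(torch.ones(2, 3))
    assert x.data_ptr() == ptr_before   # still the same underlying storage

    x = x + torch.ones(2, 3)            # out-of-place form: allocates a new result tensor
    assert x.data_ptr() != ptr_before

So the change keeps the in-place memory behaviour; the expressions simply read more like ordinary arithmetic.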

File tree

3 files changed: +10 -10 lines changed

    src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
    src/diffusers/models/transformers/cogvideox_transformer_3d.py
    src/diffusers/models/transformers/transformer_cogview3plus.py
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
Lines changed: 2 additions & 2 deletions

@@ -303,7 +303,7 @@ def forward(
         hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

         if temb is not None:
-            hidden_states.add_(self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None])
+            hidden_states += self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

         if zq is not None:
             hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))

@@ -322,7 +322,7 @@ def forward(
             else:
                 inputs = self.conv_shortcut(inputs)

-        hidden_states.add_(inputs)
+        hidden_states += inputs
         return hidden_states, new_conv_cache

src/diffusers/models/transformers/cogvideox_transformer_3d.py
Lines changed: 4 additions & 4 deletions

@@ -138,8 +138,8 @@ def forward(
             **attention_kwargs,
         )

-        hidden_states.add_(gate_msa * attn_hidden_states)
-        encoder_hidden_states.add_(enc_gate_msa * attn_encoder_hidden_states)
+        hidden_states += gate_msa * attn_hidden_states
+        encoder_hidden_states += enc_gate_msa * attn_encoder_hidden_states

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(

@@ -150,8 +150,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)

-        hidden_states.add_(gate_ff * ff_output[:, text_seq_length:])
-        encoder_hidden_states.add_(enc_gate_ff * ff_output[:, :text_seq_length])
+        hidden_states += gate_ff * ff_output[:, text_seq_length:]
+        encoder_hidden_states += enc_gate_ff * ff_output[:, :text_seq_length]

         return hidden_states, encoder_hidden_states

src/diffusers/models/transformers/transformer_cogview3plus.py
Lines changed: 4 additions & 4 deletions

@@ -106,8 +106,8 @@ def forward(
             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )

-        hidden_states.add_(gate_msa.unsqueeze(1) * attn_hidden_states)
-        encoder_hidden_states.add_(c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states)
+        hidden_states += gate_msa.unsqueeze(1) * attn_hidden_states
+        encoder_hidden_states += c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states

         # norm & modulate
         norm_hidden_states = self.norm2(hidden_states)

@@ -120,8 +120,8 @@ def forward(
         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
         ff_output = self.ff(norm_hidden_states)

-        hidden_states.add_(gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:])
-        encoder_hidden_states.add_(c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length])
+        hidden_states += gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+        encoder_hidden_states += c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]

         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
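
For readers skimming the transformer-block diffs above, the changed lines are the gated residual updates that add the attention and feed-forward outputs back onto the text and video token streams. A rough, hypothetical sketch of that pattern with made-up shapes (names like `gate_ff` and `text_seq_length` mirror the diff; everything else is illustrative, not the actual module code):

    import torch

    # Illustrative shapes only; the real tensors come from the block's norm/attention layers.
    batch, text_seq_length, video_seq_length, dim = 2, 4, 8, 16
    hidden_states = torch.randn(batch, video_seq_length, dim)
    encoder_hidden_states = torch.randn(batch, text_seq_length, dim)
    gate_ff = torch.rand(batch, 1, dim)
    enc_gate_ff = torch.rand(batch, 1, dim)

    # Stand-in for self.ff(...) applied to the concatenated [text, video] sequence.
    ff_output = torch.randn(batch, text_seq_length + video_seq_length, dim)

    # Gated residual adds, split back into the two streams, written with `+=`
    # exactly as in the updated block code.
    hidden_states += gate_ff * ff_output[:, text_seq_length:]
    encoder_hidden_states += enc_gate_ff * ff_output[:, :text_seq_length]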
