Skip to content

Commit a6d990c

Browse files
committed
update
1 parent 5316f4b commit a6d990c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def forward(
167167
spatial_shape = hidden_states.shape[-2:]
168168
spatial_noise = torch.randn(
169169
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
170-
)
171-
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, :, :]
170+
)[None]
171+
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
172172

173173
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
174174

@@ -183,8 +183,8 @@ def forward(
183183
spatial_shape = hidden_states.shape[-2:]
184184
spatial_noise = torch.randn(
185185
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
186-
)
187-
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, :, :]
186+
)[None]
187+
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
188188

189189
if self.norm3 is not None:
190190
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)

0 commit comments

Comments
 (0)