
Commit 52d2ec3

a-r-r-o-w authored and yiyixuxu committed
vae fix
1 parent 661ab0d commit 52d2ec3

1 file changed: +5 −0 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 5 additions & 0 deletions
@@ -992,7 +992,9 @@ def __init__(
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
+        self.timestep_scale_multiplier = None
         if timestep_conditioning:
+            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -1001,6 +1003,9 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

+        if self.timestep_scale_multiplier is not None:
+            temb = temb * self.timestep_scale_multiplier
+
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)

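For context, below is a minimal, self-contained sketch (not the actual LTX decoder class from diffusers) of the pattern this commit introduces: a learnable timestep_scale_multiplier initialized to 1000.0 is registered only when timestep conditioning is enabled, and forward multiplies temb by it before the timestep embedding is used. The class name, the nn.Linear stand-in for PixArtAlphaCombinedTimestepSizeEmbeddings, and the scale/shift application are illustrative assumptions, not code from this file.

import torch
import torch.nn as nn
from typing import Optional


class TimestepScaledDecoderSketch(nn.Module):
    # Hypothetical stand-in module; mirrors only the pattern added by this commit.
    def __init__(self, channels: int = 64, timestep_conditioning: bool = True):
        super().__init__()
        self.conv_in = nn.Conv3d(3, channels, kernel_size=3, padding=1)

        # As in the diff: attributes default to None and are only created
        # when timestep conditioning is enabled.
        self.time_embedder = None
        self.scale_shift_table = None
        self.timestep_scale_multiplier = None
        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
            # Simplified stand-in for PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0).
            self.time_embedder = nn.Linear(1, channels * 2)
            self.scale_shift_table = nn.Parameter(torch.randn(2, channels) / channels**0.5)

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        # The added step: scale the raw timestep (e.g. values in [0, 1]) by the
        # learnable multiplier before it is embedded.
        if self.timestep_scale_multiplier is not None and temb is not None:
            temb = temb * self.timestep_scale_multiplier

        if self.time_embedder is not None and temb is not None:
            emb = self.time_embedder(temb.view(-1, 1).float())              # (B, 2 * C)
            emb = emb.view(-1, 2, hidden_states.shape[1]) + self.scale_shift_table
            scale, shift = emb.unbind(dim=1)                                # each (B, C)
            hidden_states = hidden_states * (1 + scale[:, :, None, None, None]) + shift[:, :, None, None, None]
        return hidden_states


# Usage: a tiny video tensor (batch, channels, frames, height, width) and a
# timestep in [0, 1]; internally the timestep is scaled to roughly 500.
x = torch.randn(1, 3, 2, 8, 8)
t = torch.tensor([0.5])
print(TimestepScaledDecoderSketch()(x, t).shape)  # torch.Size([1, 64, 2, 8, 8])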