
Commit 661ab0d

yiyi add testing lines

1 parent: f950ba1

File tree: 2 files changed (+18 lines, −2 lines)

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 12 additions & 1 deletion
@@ -507,10 +507,12 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
+            print(f" after resnets: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 hidden_states = downsampler(hidden_states)
+                print(f" after downsampler: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         return hidden_states

@@ -841,6 +843,8 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `LTXVideoEncoder3d` class."""

+        print(f" inside LTXVideoEncoder3d")
+        print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
         p = self.patch_size
         p_t = self.patch_size_t

@@ -854,7 +858,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         # Thanks for driving me insane with the weird patching order :(
         hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
+        print(f" before conv_in: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
         hidden_states = self.conv_in(hidden_states)
+        print(f" after conv_in: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for down_block in self.down_blocks:
@@ -864,17 +870,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         else:
             for down_block in self.down_blocks:
                 hidden_states = down_block(hidden_states)
+                print(f" after down_block: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         hidden_states = self.mid_block(hidden_states)
+        print(f" after mid_block: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        print(f" before conv_act: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
         hidden_states = self.conv_act(hidden_states)
+        print(f" after conv_act: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
         hidden_states = self.conv_out(hidden_states)
+        print(f" after conv_out: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")

         last_channel = hidden_states[:, -1:]
         last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
         hidden_states = torch.cat([hidden_states, last_channel], dim=1)
-
+        print(f" output: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
         return hidden_states
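The prints added above trace `hidden_states` through `LTXVideoEncoder3d.forward`, logging the shape and a small corner slice (`[0, 0, :3, :3, :3]`) after each stage. For reference, the same trace can be collected without editing the model by registering forward hooks; this is a minimal sketch, assuming the encoder exposes the `conv_in`, `mid_block`, `conv_act`, and `conv_out` submodules named in the diff (the `attach_debug_hooks` helper is illustrative, not part of diffusers):

import torch

def attach_debug_hooks(encoder):
    # Illustrative helper, not part of diffusers: print the shape and the
    # same corner slice as the inline prints, after each named submodule.
    handles = []
    for name in ["conv_in", "mid_block", "conv_act", "conv_out"]:
        module = getattr(encoder, name)

        def make_hook(tag):
            def hook(mod, args, output):
                print(f" after {tag}: {output.shape}, {output[0, 0, :3, :3, :3]}")
            return hook

        handles.append(module.register_forward_hook(make_hook(name)))
    return handles

# Usage sketch: handles = attach_debug_hooks(vae.encoder), run the encode,
# then call handle.remove() on each handle to restore the clean forward pass.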
src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 6 additions & 1 deletion
@@ -626,14 +626,19 @@ def prepare_latents(
         print(f" before encode: {data.shape}, {data.dtype}, {data.device}")

         condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+        print(f" after encode: {condition_latents.shape}, {condition_latents.dtype}, {condition_latents.device}")
+        print(condition_latents[0,0,:3,:5,:5])
+        condition_latents_before_normalize = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt")
+        print(torch.sum((condition_latents_before_normalize - condition_latents).abs()))
+        assert False
         condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)

         print(f" after normalize: {condition_latents.shape}")
         print(condition_latents[0,0,:3,:5,:5])
         condition_latents_loaded = torch.load("/raid/yiyi/LTX-Video/latents_normalized.pt")
         print(condition_latents_loaded.shape)
         print(condition_latents_loaded[0,0,:3,:5,:5])
-        print(torch.sum((condition_latents_loaded - condition_latents).abs()))
+        print(torch.sum((condition_latents_loaded.to(condition_latents.device) - condition_latents).abs()))
         assert False

         num_data_frames = data.size(2)
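Besides the new pre-normalization check, this hunk fixes a device mismatch in the existing comparison: `torch.load` can return the reference tensor on a different device than `condition_latents`, so the reference is moved with `.to(condition_latents.device)` before subtracting. A minimal sketch of this dump-and-compare debugging pattern (the helper name, path, and tolerance are illustrative; only the `torch.load` / `.abs()` sum comparison comes from the diff):

import torch

# In the reference implementation, dump the tensor to compare against, e.g.:
#   torch.save(latents, "latents_before_normalize.pt")  # illustrative path

def compare_to_reference(tensor, path, atol=1e-4):
    # map_location avoids the cross-device subtraction this commit fixes
    reference = torch.load(path, map_location=tensor.device)
    diff = (reference - tensor).abs()
    print(f"sum abs diff: {diff.sum().item()}, max abs diff: {diff.max().item()}")
    return bool(diff.max().item() <= atol)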
