
Commit 658d533

Commit message: up

1 parent 9fa964b

File tree: 2 files changed, 47 additions & 45 deletions


src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 6 additions & 3 deletions
@@ -507,7 +507,7 @@ def forward(
                 hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
             else:
                 hidden_states = resnet(hidden_states, temb, generator)
-            print(f" after resnets: {hidden_states.shape}, {hidden_states[0,0,:3,:3,:3]}")
+
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -1116,6 +1116,8 @@ def __init__(
         scaling_factor: float = 1.0,
         encoder_causal: bool = True,
         decoder_causal: bool = False,
+        spatial_compression_ratio: int = None,
+        temporal_compression_ratio: int = None,
     ) -> None:
         super().__init__()
 
@@ -1153,8 +1155,9 @@ def __init__(
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
 
-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+
+        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
+        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
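Note: with the stock LTX-Video VAE configuration (patch_size=4, patch_size_t=1, three of four blocks downsampling; these values are assumptions based on the released checkpoints, not part of this commit), the fallback expressions above reproduce the familiar 32x spatial / 8x temporal compression, and the new constructor arguments simply let a config override them. A minimal sketch of the fallback logic:

# Sketch of the fallback computation added above; config values are assumed
# from the public LTX-Video VAE defaults.
patch_size = 4
patch_size_t = 1
spatio_temporal_scaling = (True, True, True, False)  # 3 of 4 blocks downsample

spatial_override = None   # e.g. passed as spatial_compression_ratio=...
temporal_override = None

spatial = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_override is None else spatial_override
temporal = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_override is None else temporal_override
print(spatial, temporal)  # -> 32 8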

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 41 additions & 42 deletions
@@ -435,10 +435,11 @@ def check_inputs(
                 f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                 f" {negative_prompt_attention_mask.shape}."
             )
-
+
+
     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
         # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +448,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
@@ -458,7 +469,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords
 
     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
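Note: `_pack_latents` now returns, alongside the packed tokens, a per-token (frame, row, col) coordinate grid. A hypothetical toy run (patch_size = patch_size_t = 1, made-up sizes) showing what `latent_coords` contains:

import torch

# Toy illustration of the coordinate grid built above (sizes are made up).
batch_size, num_frames, height, width = 1, 2, 2, 3
coords = torch.meshgrid(
    torch.arange(num_frames), torch.arange(height), torch.arange(width), indexing="ij"
)
coords = torch.stack(coords, dim=0)                        # [3, F, H, W]
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
coords = coords.reshape(batch_size, 3, -1)                 # [B, 3, F*H*W]: one (f, h, w) column per token
print(coords[0, :, :4])
# tensor([[0, 0, 0, 0],     <- frame index of tokens 0..3
#         [0, 0, 0, 1],     <- row index
#         [0, 1, 2, 0]])    <- column index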
@@ -588,6 +599,7 @@ def prepare_latents(
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
 
         extra_conditioning_latents = []
         extra_conditioning_rope_interpolation_scales = []
@@ -605,14 +617,6 @@ def prepare_latents(
                 num_frames_input = data.size(2)
                 num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
                 data = data[:, :, :num_frames_output]
-
-                print(data.shape)
-                print(data[0,0,:3,:5,:5])
-                data_loaded = torch.load("/raid/yiyi/LTX-Video/media_item.pt")
-                print(data_loaded.shape)
-                print(data_loaded[0,0,:3,:5,:5])
-                print(torch.sum((data_loaded - data).abs()))
-                print(f" dtype:{dtype}, device:{device}")
                 data = data.to(device, dtype=torch.bfloat16)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
@@ -623,23 +627,11 @@ def prepare_latents(
                     f"but got {data.size(2)} frames."
                 )
 
-            print(f" before encode: {data.shape}, {data.dtype}, {data.device}")
-
             condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            print(f" after encode: {condition_latents.shape}, {condition_latents.dtype}, {condition_latents.device}")
-            print(condition_latents[0,0,:3,:5,:5])
-            condition_latents_before_normalize = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt")
-            print(torch.sum((condition_latents_before_normalize - condition_latents).abs()))
-            assert False
-            condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)
-
-            print(f" after normalize: {condition_latents.shape}")
-            print(condition_latents[0,0,:3,:5,:5])
-            condition_latents_loaded = torch.load("/raid/yiyi/LTX-Video/latents_normalized.pt")
-            print(condition_latents_loaded.shape)
-            print(condition_latents_loaded[0,0,:3,:5,:5])
-            print(torch.sum((condition_latents_loaded.to(condition_latents.device) - condition_latents).abs()))
-            assert False
+            condition_latents_loaded = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt").to(condition_latents.device)
+            print(f" condition_latents(loaded): {condition_latents_loaded.shape}, {condition_latents_loaded[0,0,:3,:3,:3]}")
+            print(f" condition_latents: {condition_latents.shape}, {condition_latents[0,0,:3,:3,:3]}")
+            condition_latents = self._normalize_latents(condition_latents_loaded, self.vae.latents_mean, self.vae.latents_std)
 
             num_data_frames = data.size(2)
             num_cond_frames = condition_latents.size(2)
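Note: the hunk above swaps the freshly encoded latents for reference latents loaded from disk, then normalizes them. For context, the `_normalize_latents` helper in the LTX pipelines is a per-channel affine map over [B, C, F, H, W] latents, roughly like the following sketch (an approximation, not the verbatim implementation):

import torch

def normalize_latents(latents, latents_mean, latents_std, scaling_factor=1.0):
    # Broadcast per-channel statistics over batch, frame, height, and width.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - latents_mean) * scaling_factor / latents_std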
@@ -667,7 +659,7 @@ def prepare_latents(
                 noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                 condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                 c_nlf = condition_latents.shape[2]
-                condition_latents = self._pack_latents(
+                condition_latents, latent_coords = self._pack_latents(
                     condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                 )
                 conditioning_mask = torch.full(
@@ -692,30 +684,37 @@ def prepare_latents(
             extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
             extra_conditioning_mask.append(conditioning_mask)
 
-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, latent_coords = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
+
+        pixel_coords = (
+            latent_coords
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
         )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+
+        rope_interpolation_scale = pixel_coords
+
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, latent_coords[:, 0]
         )
 
+        ## YiYi Todo: not looked into yet
         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             rope_interpolation_scale = torch.cat(
                 [*extra_conditioning_rope_interpolation_scales, rope_interpolation_scale], dim=2
             )
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
 
+
+        print(f" latents (after pack): {latents.shape}, {latents[0,:3,:3]}")
+        print(f" conditioning_mask: {conditioning_mask.shape}, {conditioning_mask[0,:10]}")
+        print(f" rope_interpolation_scale: {rope_interpolation_scale.shape}, {rope_interpolation_scale[0,:3,:3]}")
+        print(f" extra_conditioning_num_latents: {extra_conditioning_num_latents}")
+        assert False
         return latents, conditioning_mask, rope_interpolation_scale, extra_conditioning_num_latents
 
     @property
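Note: instead of a dense per-position interpolation-scale tensor, the RoPE input is now the per-token pixel coordinate, obtained by scaling the latent (frame, row, col) coordinates by the VAE compression ratios; the temporal shift-and-clamp accounts for the causal VAE, whose first latent frame decodes to a single pixel frame. An illustrative example with the usual LTX ratios (8 temporal, 32 spatial; the concrete numbers here are assumptions):

import torch

# Three tokens at latent (frame, row, col) coordinates (0,0,0), (0,1,0), (1,0,1).
latent_coords = torch.tensor([[[0, 0, 1],
                               [0, 1, 0],
                               [0, 0, 1]]])            # [B=1, 3, num_tokens=3]
ratios = torch.tensor([8, 32, 32])                     # (temporal, spatial, spatial)
pixel_coords = latent_coords * ratios[None, :, None]
# Latent frame f starts at pixel frame max(0, f*8 + 1 - 8): frame 0 -> 0, frame 1 -> 1.
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - 8).clamp(min=0)
print(pixel_coords)
# tensor([[[ 0,  0,  1],
#          [ 0, 32,  0],
#          [ 0,  0, 32]]])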
@@ -914,7 +913,7 @@ def __call__(
             frame_rate,
             generator,
             device,
-            torch.float32,
+            prompt_embeds.dtype,
         )
         init_latents = latents.clone()
 