
Commit e202e46

Commit message: "up"
Parent: 14a2282

2 files changed: +59 additions, −24 deletions


src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 4 additions & 2 deletions

@@ -1105,6 +1105,8 @@ def __init__(
         scaling_factor: float = 1.0,
         encoder_causal: bool = True,
         decoder_causal: bool = False,
+        spatial_compression_ratio: int = None,
+        temporal_compression_ratio: int = None,
     ) -> None:
         super().__init__()

@@ -1142,8 +1144,8 @@ def __init__(
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)

-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
+        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio

         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.
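For context, the fallback added above reduces to a one-line rule: an explicit config value wins, otherwise the ratio is derived from the architecture (patch size doubled once per downsampling stage). A minimal sketch of that rule, with a hypothetical helper name and LTX-like arguments:

```python
from typing import List, Optional

def compression_ratio(patch_size: int, spatio_temporal_scaling: List[bool], override: Optional[int] = None) -> int:
    # sum(...) counts the downsampling stages (True == 1); each stage halves the resolution.
    return patch_size * 2 ** sum(spatio_temporal_scaling) if override is None else override

assert compression_ratio(4, [True, True, True, False]) == 32  # derived: 4 * 2**3
assert compression_ratio(4, [True, True, True, False], 8) == 8  # explicit config value wins
```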

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 55 additions & 22 deletions

@@ -437,8 +437,8 @@ def check_inputs(
         )

     @staticmethod
-    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+    # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
         # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +447,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
@@ -458,7 +468,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords

     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
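To see what the second return value of `_pack_latents` carries, here is a self-contained sketch of the coordinate grid (shapes assumed; LTX uses patch sizes of 1, and `indexing="ij"` is spelled out here although the diff relies on the default):

```python
import torch

batch_size, num_frames, height, width = 2, 4, 8, 8
patch_size_t, patch_size = 1, 1
device = torch.device("cpu")

coords = torch.meshgrid(
    torch.arange(0, num_frames, patch_size_t, device=device),
    torch.arange(0, height, patch_size, device=device),
    torch.arange(0, width, patch_size, device=device),
    indexing="ij",
)
coords = torch.stack(coords, dim=0)                          # [3, F, H, W]
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]
coords = coords.reshape(batch_size, -1, num_frames * height * width)

# One (frame, row, col) triple per packed token:
print(coords.shape)  # torch.Size([2, 3, 256])
```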
@@ -544,6 +554,25 @@ def _prepare_non_first_frame_conditioning(

         return latents, condition_latents, condition_latent_frames_mask

+    def trim_conditioning_sequence(
+        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
+    ):
+        """
+        Trim a conditioning sequence to the allowed number of frames.
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+        Returns:
+            int: updated sequence length
+        """
+        scale_factor = self.vae_temporal_compression_ratio
+        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+        # Trim down to a multiple of temporal_scale_factor frames plus 1
+        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+        return num_frames
+
+
     def prepare_latents(
         self,
         conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
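The trim rule enforces LTX's frame-count constraint (a multiple of the temporal compression ratio, plus 1) while also capping the sequence at the frames remaining after `start_frame`. A worked example of the arithmetic above, assuming the usual temporal ratio of 8:

```python
def trim(start_frame: int, sequence_num_frames: int, target_num_frames: int, scale_factor: int = 8) -> int:
    # Mirror of trim_conditioning_sequence, freed of pipeline state for illustration.
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    return (num_frames - 1) // scale_factor * scale_factor + 1

assert trim(0, 30, 33) == 25  # 30 frames round down to the nearest 8k + 1
assert trim(0, 40, 33) == 33  # capped by the target length first (33 is already 8*4 + 1)
assert trim(8, 40, 33) == 25  # a later start frame leaves only 25 frames to fill
```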
@@ -579,7 +608,11 @@ def prepare_latents(
            if condition.image is not None:
                data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
            elif condition.video is not None:
-                data = self.video_processor.preprocess_video(condition.vide, height, width)
+                data = self.video_processor.preprocess_video(condition.video, height, width)
+                num_frames_input = data.size(2)
+                num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
+                data = data[:, :, :num_frames_output]
+                data = data.to(device, dtype=dtype)
            else:
                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")

@@ -599,6 +632,7 @@ def prepare_latents(
                    latents[:, :, :num_cond_frames], condition_latents, condition.strength
                )
                condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
+            # YiYi TODO: code path not tested
            else:
                if num_data_frames > 1:
                    (
@@ -617,8 +651,8 @@ def prepare_latents(
                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                c_nlf = condition_latents.shape[2]
-                condition_latents = self._pack_latents(
-                    condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                condition_latents, condition_latent_coords = self._pack_latents(
+                    condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                )
                conditioning_mask = torch.full(
                    condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
@@ -642,23 +676,22 @@ def prepare_latents(
                extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
                extra_conditioning_mask.append(conditioning_mask)

-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, latent_coords = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
        )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
+        pixel_coords = (
+            latent_coords
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
        )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+
+        rope_interpolation_scale = pixel_coords
+
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, latent_coords[:, 0]
        )

+        # YiYi TODO: code path not tested
        if len(extra_conditioning_latents) > 0:
            latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
            rope_interpolation_scale = torch.cat(
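To make the coordinate rework concrete: rather than a dense per-position scale tensor, RoPE now receives the pixel-space coordinates of each token. A small sketch with assumed LTX ratios (temporal 8, spatial 32); the `+ 1 - ratio` shift accounts for the causal VAE, whose first latent frame decodes to a single pixel frame:

```python
import torch

ratios = torch.tensor([8, 32, 32])  # assumed (temporal, spatial, spatial) compression
latent_coords = torch.tensor([[[0, 1, 2],    # frame index per token
                               [0, 0, 0],    # row
                               [0, 0, 0]]])  # col -> shape [B=1, 3, tokens=3]

pixel_coords = latent_coords * ratios[None, :, None]
# Causal shift: latent frame 0 covers one pixel frame, so frame t starts at
# max(0, 8*t + 1 - 8) rather than 8*t.
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - 8).clamp(min=0)
print(pixel_coords[:, 0])  # tensor([[0, 1, 9]])

# The per-token conditioning strength is gathered by each token's latent frame:
condition_latent_frames_mask = torch.tensor([[1.0, 0.5, 0.0]])
conditioning_mask = condition_latent_frames_mask.gather(1, latent_coords[:, 0])
print(conditioning_mask)  # tensor([[1.0000, 0.5000, 0.0000]])
```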
@@ -864,7 +897,7 @@ def __call__(
            frame_rate,
            generator,
            device,
-            torch.float32,
+            prompt_embeds.dtype,
        )
        init_latents = latents.clone()

@@ -955,8 +988,8 @@ def __call__(
                pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
-                latents = self._pack_latents(
-                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                latents, _ = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                )

                if callback_on_step_end is not None:
