6 changes: 4 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -1105,6 +1105,8 @@ def __init__(
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()

@@ -1142,8 +1144,8 @@ def __init__(
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio

# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
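The two new optional config arguments let a checkpoint state its compression ratios explicitly, while older configs keep deriving them from the patchification and downscaling setup. A minimal sketch of the fallback logic, using illustrative values rather than a real LTX VAE config:

```python
# Illustrative values, not a real LTX VAE config.
patch_size = 4
patch_size_t = 1
spatio_temporal_scaling = (True, True, True, False)  # three 2x downscaling stages

spatial_compression_ratio = None  # not provided -> derived from the architecture
temporal_compression_ratio = 8    # provided -> used as-is

derived_spatial = patch_size * 2 ** sum(spatio_temporal_scaling)     # 4 * 2**3 = 32
derived_temporal = patch_size_t * 2 ** sum(spatio_temporal_scaling)  # 1 * 2**3 = 8

spatial = derived_spatial if spatial_compression_ratio is None else spatial_compression_ratio
temporal = derived_temporal if temporal_compression_ratio is None else temporal_compression_ratio
print(spatial, temporal)  # 32 8
```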
77 changes: 55 additions & 22 deletions src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -437,8 +437,8 @@ def check_inputs(
)

@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +447,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size

latent_sample_coords = torch.meshgrid(
torch.arange(0, num_frames, patch_size_t, device=device),
torch.arange(0, height, patch_size, device=device),
torch.arange(0, width, patch_size, device=device),
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)

latents = latents.reshape(
batch_size,
-1,
@@ -458,7 +468,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
return latents, latent_coords
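Besides the flattened tokens, `_pack_latents` now also returns the (frame, row, column) index of every token in the unpacked latent grid; downstream code converts these to pixel-space coordinates for RoPE. A standalone sketch of the same idea, slightly restructured and assuming the pipeline's patch sizes of 1 (the helper name is illustrative, not the pipeline's):

```python
import torch

def pack_latents_with_coords(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device=None):
    # latents: [B, C, F, H, W] -> tokens: [B, F//p_t * H//p * W//p, C * p_t * p * p]
    batch_size, _, num_frames, height, width = latents.shape

    # (t, y, x) coordinate of the top-left corner of every patch, shape [B, 3, num_tokens].
    coords = torch.meshgrid(
        torch.arange(0, num_frames, patch_size_t, device=device),
        torch.arange(0, height, patch_size, device=device),
        torch.arange(0, width, patch_size, device=device),
        indexing="ij",
    )
    coords = torch.stack(coords, dim=0).unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
    coords = coords.reshape(batch_size, 3, -1)

    latents = latents.reshape(
        batch_size, -1,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    tokens = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return tokens, coords

x = torch.randn(1, 128, 8, 16, 16)
tokens, coords = pack_latents_with_coords(x)
print(tokens.shape, coords.shape)  # torch.Size([1, 2048, 128]) torch.Size([1, 3, 2048])
```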

@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
@@ -544,6 +554,25 @@ def _prepare_non_first_frame_conditioning(

return latents, condition_latents, condition_latent_frames_mask

def trim_conditioning_sequence(
self, start_frame: int, sequence_num_frames: int, target_num_frames: int
):
"""
Trim a conditioning sequence to the allowed number of frames.
Args:
start_frame (int): The target frame number of the first frame in the sequence.
sequence_num_frames (int): The number of frames in the sequence.
target_num_frames (int): The target number of frames in the generated video.
Returns:
int: updated sequence length
"""
scale_factor = self.vae_temporal_compression_ratio
num_frames = min(sequence_num_frames, target_num_frames - start_frame)
# Trim down to a multiple of temporal_scale_factor frames plus 1
num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
return num_frames
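The trimming rule keeps the longest prefix of the conditioning clip that both fits before the end of the generated video and has a length of the form k * scale_factor + 1, which is what the causal temporal compression expects. A worked example, assuming a temporal compression ratio of 8:

```python
# Worked example of the trimming rule, assuming vae_temporal_compression_ratio == 8.
scale_factor = 8
start_frame, sequence_num_frames, target_num_frames = 0, 50, 121

num_frames = min(sequence_num_frames, target_num_frames - start_frame)  # min(50, 121) = 50
num_frames = (num_frames - 1) // scale_factor * scale_factor + 1        # 49 // 8 * 8 + 1 = 49
print(num_frames)  # 49: the 50-frame clip is trimmed to 6 * 8 + 1 frames
```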


def prepare_latents(
self,
conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
@@ -579,7 +608,11 @@ def prepare_latents(
if condition.image is not None:
data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
elif condition.video is not None:
data = self.video_processor.preprocess_video(condition.vide, height, width)
data = self.video_processor.preprocess_video(condition.video, height, width)
num_frames_input = data.size(2)
num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
data = data[:, :, :num_frames_output]
data = data.to(device, dtype=dtype)
else:
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")

@@ -599,6 +632,7 @@ def prepare_latents(
latents[:, :, :num_cond_frames], condition_latents, condition.strength
)
condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
# YiYi TODO: code path not tested
else:
if num_data_frames > 1:
(
@@ -617,8 +651,8 @@ def prepare_latents(
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, condition.strength)
c_nlf = condition_latents.shape[2]
condition_latents = self._pack_latents(
condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
condition_latents, condition_latent_coords = self._pack_latents(
condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)
conditioning_mask = torch.full(
condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
@@ -642,23 +676,22 @@ def prepare_latents(
extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
extra_conditioning_mask.append(conditioning_mask)

latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
latents, latent_coords = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)
rope_interpolation_scale = [
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
]
rope_interpolation_scale = (
torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
.view(-1, 1, 1, 1, 1)
.repeat(1, 1, num_latent_frames, latent_height, latent_width)
pixel_coords = (
latent_coords
* torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
)
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)

rope_interpolation_scale = pixel_coords

conditioning_mask = condition_latent_frames_mask.gather(
1, latent_coords[:, 0]
)
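Instead of materializing a dense per-position `rope_interpolation_scale` tensor, the pipeline now scales each token's latent (t, y, x) coordinate by the VAE compression ratios to get pixel-space coordinates, shifts the temporal axis so the first latent frame (which the causal VAE encodes from a single pixel frame) maps to pixel frame 0, and gathers the per-frame conditioning strengths into a per-token mask. A minimal sketch of those three steps with made-up shapes and values:

```python
import torch

# Hypothetical sizes: 2048 tokens, 8 latent frames, temporal ratio 8, spatial ratio 32.
batch_size, num_tokens, num_latent_frames = 1, 2048, 8
latent_coords = torch.randint(0, num_latent_frames, (batch_size, 3, num_tokens))  # (t, y, x) per token
condition_latent_frames_mask = torch.zeros(batch_size, num_latent_frames)
condition_latent_frames_mask[:, :2] = 0.9  # first two latent frames carry conditioning

scale = torch.tensor([8, 32, 32])  # temporal, spatial, spatial compression ratios
pixel_coords = latent_coords * scale[None, :, None]

# Shift the temporal axis so latent frame 0 maps to pixel frame 0 rather than frame (ratio - 1).
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - 8).clamp(min=0)

# Per-token conditioning mask: look up each token's latent frame index in the per-frame mask.
conditioning_mask = condition_latent_frames_mask.gather(1, latent_coords[:, 0])
print(pixel_coords.shape, conditioning_mask.shape)  # torch.Size([1, 3, 2048]) torch.Size([1, 2048])
```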

# YiYi TODO: code path not tested
if len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
rope_interpolation_scale = torch.cat(
@@ -864,7 +897,7 @@ def __call__(
frame_rate,
generator,
device,
torch.float32,
prompt_embeds.dtype,
Contributor review comment: We could use float32 here and then typecast before sending into transformer, no? That way there won't be a downcast/upcast for CFG

)
init_latents = latents.clone()
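The `prepare_latents` call now uses the prompt-embedding dtype instead of float32; the review comment above suggests the alternative of preparing latents in float32 and casting once right before the transformer, so classifier-free guidance doesn't round-trip through a lower precision. A hedged sketch of that alternative (variable names are illustrative, not the pipeline's):

```python
import torch

# Hypothetical: keep latents in float32 during preparation and scheduling...
latents = torch.randn(1, 2048, 128, dtype=torch.float32)

# ...and cast a single time right before the transformer forward pass.
transformer_dtype = torch.bfloat16  # e.g. prompt_embeds.dtype
latent_model_input = latents.to(transformer_dtype)

# With classifier-free guidance the batch is duplicated after the cast,
# so the float32 copy is never downcast and re-upcast between steps.
latent_model_input = torch.cat([latent_model_input] * 2)
```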

@@ -955,8 +988,8 @@ def __call__(
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
latents, _ = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)
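At the end of each denoising step in this path, the first latent frame is carried over unchanged while the scheduler output replaces the remaining frames, and everything is then flattened back into tokens (the returned coordinates are unchanged, so they are discarded). A compact sketch with illustrative shapes, assuming patch sizes of 1:

```python
import torch

latents = torch.randn(1, 128, 8, 16, 16)       # [B, C, F, H, W], unpacked
pred_latents = torch.randn(1, 128, 7, 16, 16)  # scheduler output for frames 1..F-1

# Keep the first (conditioning) frame, splice in the denoised frames, re-tokenize.
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
tokens = latents.permute(0, 2, 3, 4, 1).flatten(1, 3)  # equivalent to _pack_latents with patch sizes 1
print(tokens.shape)  # torch.Size([1, 2048, 128])
```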

if callback_on_step_end is not None: