6 changes: 4 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -1105,6 +1105,8 @@ def __init__(
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()

@@ -1142,8 +1144,8 @@ def __init__(
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)

self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) if spatial_compression_ratio is None else spatial_compression_ratio
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) if temporal_compression_ratio is None else temporal_compression_ratio

# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
101 changes: 64 additions & 37 deletions src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -437,8 +437,8 @@ def check_inputs(
)

@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +447,17 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size

latent_sample_coords = torch.meshgrid(
torch.arange(0, num_frames, patch_size_t, device=device),
torch.arange(0, height, patch_size, device=device),
torch.arange(0, width, patch_size, device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)

latents = latents.reshape(
batch_size,
-1,
Expand All @@ -458,7 +469,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
return latents, latent_coords

@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
@@ -503,10 +514,10 @@ def _prepare_non_first_frame_conditioning(
frame_index: int,
strength: float,
num_prefix_latent_frames: int = 2,
prefix_latents_mode: str = "soft",
prefix_latents_mode: str = "concat",
prefix_soft_conditioning_strength: float = 0.15,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_latent_frames = latents.size(2)
num_latent_frames = condition_latents.size(2)

if num_latent_frames < num_prefix_latent_frames:
raise ValueError(
@@ -544,6 +555,25 @@ def _prepare_non_first_frame_conditioning(

return latents, condition_latents, condition_latent_frames_mask

def trim_conditioning_sequence(
self, start_frame: int, sequence_num_frames: int, target_num_frames: int
):
"""
Trim a conditioning sequence to the allowed number of frames.
Args:
start_frame (int): The target frame number of the first frame in the sequence.
sequence_num_frames (int): The number of frames in the sequence.
target_num_frames (int): The target number of frames in the generated video.
Returns:
int: updated sequence length
"""
scale_factor = self.vae_temporal_compression_ratio
num_frames = min(sequence_num_frames, target_num_frames - start_frame)
# Trim down to a multiple of temporal_scale_factor frames plus 1
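# e.g. with scale_factor=8, 30 input frames are trimmed to (30 - 1) // 8 * 8 + 1 = 25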
num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
return num_frames


def prepare_latents(
self,
conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
@@ -573,13 +603,17 @@ def prepare_latents(
extra_conditioning_num_latents = (
0 # Number of extra conditioning latents added (should be removed before decoding)
)
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype)
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)

for condition in conditions:
if condition.image is not None:
data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
elif condition.video is not None:
data = self.video_processor.preprocess_video(condition.vide, height, width)
data = self.video_processor.preprocess_video(condition.video, height, width)
num_frames_input = data.size(2)
num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
data = data[:, :, :num_frames_output]
data = data.to(device, dtype=dtype)
else:
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")

@@ -599,6 +633,7 @@ def prepare_latents(
latents[:, :, :num_cond_frames], condition_latents, condition.strength
)
condition_latent_frames_mask[:, :num_cond_frames] = condition.strength

else:
if num_data_frames > 1:
(
@@ -617,47 +652,39 @@ def prepare_latents(
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, condition.strength)
c_nlf = condition_latents.shape[2]
condition_latents = self._pack_latents(
condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
condition_latents, rope_interpolation_scale = self._pack_latents(
condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)

rope_interpolation_scale = (
rope_interpolation_scale *
torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
)
rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
rope_interpolation_scale[:, 0] += condition.frame_index

Contributor
I don't think this is compatible with what we do in the LTXRotaryPosEmbed layer... We prepare the meshgrid there and only pass the interpolation scales from the pipeline. It seems like here we are preparing the meshgrid beforehand, which will be incorrect. I think we would have to do one of the following:

  • Make sure to pass only the multiplicative interpolation scale, without first multiplying it with the latent_coords (the screenshot below shows how I handled it in the other PR; a minimal sketch of this option also appears after this thread).
  • If we're passing latent_coords, we will have to handle it differently in the transformer for LTX v0.9.0/v0.9.1 vs v0.9.5.

[screenshot from the other PR: interpolation-scale handling]

Collaborator Author

@a-r-r-o-w

I think we have different ways to handle RoPE in our current code base. In general, I think it's more convenient/natural to prepare the position ids (e.g. the image_ids and text_ids in Flux, or the grid here for LTX) at the same time as we patchify the latents (e.g. _pack_latents for LTX or Flux). Flux and LTX do this in the pipeline, while other models like Lumina handle both together inside the transformer with a patch embed.

I think it is OK to have this flexibility for RoPE, since it's something that slows us down for each integration. Maybe a general rule is to follow the original code base more closely and fit it into one of the patterns that's easier for us to maintain.
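
For illustration, a minimal sketch of the first option above: the pipeline passes only per-axis multiplicative scales (e.g. vae_temporal_compression_ratio / frame_rate, vae_spatial_compression_ratio, vae_spatial_compression_ratio) and the rotary-embedding layer expands them over a meshgrid it builds itself. The helper name and signature below are assumptions for this sketch, not diffusers APIs and not the code in this PR.

import torch

def rope_grid_from_scales(num_frames, height, width, interpolation_scale, device=None):
    # Hypothetical helper: build the (frame, height, width) meshgrid inside the RoPE
    # layer and multiply it by the per-axis scales passed from the pipeline.
    grid = torch.meshgrid(
        torch.arange(num_frames, device=device),
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    grid = torch.stack(grid, dim=0).float()  # [3, F, H, W]
    scales = torch.tensor(interpolation_scale, device=device, dtype=torch.float32).view(3, 1, 1, 1)
    # Any per-condition frame offset (condition.frame_index) would be added after this
    # multiplication, not before it.
    return grid * scales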

conditioning_mask = torch.full(
condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
)

rope_interpolation_scale = [
# TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation
# scale with the grid.
(self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
]
Comment on lines -627 to -633
Contributor
@yiyixuxu Pardon my stupidity, but I can't seem to find if we're handling this + condition.frame_index part. Is this missing by any chance, or was I mistaken in trying to handle this here?

In the original code, this is what I was meaning to handle: https://github.com/Lightricks/LTX-Video/blob/496dc5058f4408dcb777282f3fb6377fb2da08e6/ltx_video/pipelines/pipeline_ltx_video.py#L1285

rope_interpolation_scale = (
torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
.view(-1, 1, 1, 1, 1)
.repeat(1, 1, c_nlf, latent_height, latent_width)
)
extra_conditioning_num_latents += condition_latents.size(1)

extra_conditioning_latents.append(condition_latents)
extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
extra_conditioning_mask.append(conditioning_mask)

latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
latents, rope_interpolation_scale = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)
rope_interpolation_scale = [
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
]
rope_interpolation_scale = (
torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
.view(-1, 1, 1, 1, 1)
.repeat(1, 1, num_latent_frames, latent_height, latent_width)
conditioning_mask = condition_latent_frames_mask.gather(
1, rope_interpolation_scale[:, 0]
)
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size

rope_interpolation_scale = (
rope_interpolation_scale
* torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
)
rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)

if len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
@@ -864,7 +891,7 @@ def __call__(
frame_rate,
generator,
device,
torch.float32,
prompt_embeds.dtype,
Contributor
We could use float32 here and then typecast before sending it into the transformer, no? That way there won't be a downcast/upcast for CFG.
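
As a sketch of that suggestion only (a hypothetical helper, not this PR's or the pipeline's code), the cast would happen once per step, right before the transformer forward:

import torch

def cfg_forward_in_model_dtype(transformer, latents_fp32, **transformer_kwargs):
    # Keep the pipeline-side latents in float32; duplicate for classifier-free guidance
    # while still in float32, then cast once to the transformer's parameter dtype.
    latent_model_input = torch.cat([latents_fp32] * 2)
    model_dtype = next(transformer.parameters()).dtype  # e.g. bfloat16
    return transformer(hidden_states=latent_model_input.to(model_dtype), **transformer_kwargs)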

)
init_latents = latents.clone()

@@ -955,8 +982,8 @@ def __call__(
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
latents, _ = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
)

if callback_on_step_end is not None: