Merged

Commits (23)
038ef09  Refactor `LTXConditionPipeline` to add text-only conditioning (tolgacangoz, Mar 30, 2025)
5f41f54  style (tolgacangoz, Mar 30, 2025)
8c14ffb  up (tolgacangoz, Mar 30, 2025)
7f2b8cd  Refactor `LTXConditionPipeline` to streamline condition handling and … (tolgacangoz, Mar 30, 2025)
b76c47e  Improve condition checks (tolgacangoz, Mar 30, 2025)
2e288ef  Simplify latents handling based on conditioning type (tolgacangoz, Mar 30, 2025)
149580e  Refactor rope_interpolation_scale preparation for clarity and efficiency (tolgacangoz, Mar 30, 2025)
ef03f6c  Update LTXConditionPipeline docstring to clarify supported input types (tolgacangoz, Mar 30, 2025)
a39a5ee  Add LTX Video 0.9.5 model to documentation (tolgacangoz, Mar 30, 2025)
f17dd3a  Clarify documentation to indicate support for text-only conditioning … (tolgacangoz, Apr 1, 2025)
4a67eca  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 1, 2025)
c388ad7  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 1, 2025)
4e5d4ab  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 3, 2025)
8b27a68  refactor: comment out unused parameters in LTXConditionPipeline (tolgacangoz, Apr 3, 2025)
046af8e  fix: restore previously commented parameters in LTXConditionPipeline (tolgacangoz, Apr 3, 2025)
be98184  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 3, 2025)
d78703a  fix: remove unused parameters from LTXConditionPipeline (tolgacangoz, Apr 3, 2025)
0599ff3  refactor: remove unnecessary lines in LTXConditionPipeline (tolgacangoz, Apr 3, 2025)
5ae1e54  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 3, 2025)
a2d83a8  Merge branch 'make-LTX0.9.5-works-with-text-to-video' of github.com:t… (tolgacangoz, Apr 3, 2025)
fc9cfa8  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 3, 2025)
9cb4a01  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 4, 2025)
c2e15dc  Merge branch 'main' into make-LTX0.9.5-works-with-text-to-video (tolgacangoz, Apr 4, 2025)
1 change: 1 addition & 0 deletions docs/source/en/api/pipelines/ltx_video.md
@@ -32,6 +32,7 @@ Available models:
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |

Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.

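For readers landing here from the docs change, a minimal loading sketch that follows the dtype note above; the repo id `Lightricks/LTX-Video-0.9.5` is assumed here as the diffusers-format checkpoint and is not part of this diff:

```python
import torch
from diffusers import LTXConditionPipeline

# Transformer in the recommended torch.bfloat16; the VAE and text encoder
# are loaded in bfloat16 as well, matching the note above.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5",  # assumed checkpoint id, not taken from this PR
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```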
251 changes: 132 additions & 119 deletions src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -14,7 +14,7 @@

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import PIL.Image
import torch
@@ -75,6 +75,7 @@

>>> # Generate video
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # Text-only conditioning is also supported without the need to pass `conditions`
>>> video = pipe(
... conditions=[condition1, condition2],
... prompt=prompt,
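As the added comment indicates, `conditions` can now be omitted entirely. A minimal text-only call might look like the following sketch (checkpoint id, resolution, frame count, and step count are illustrative, not taken from this PR):

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16  # assumed checkpoint id
)
pipe.to("cuda")

prompt = "A clear mountain lake at sunrise, gentle ripples spreading across the water"
generator = torch.Generator("cuda").manual_seed(0)

# No `conditions`, `image`, or `video`: pure text-to-video generation.
video = pipe(
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,  # of the form k * 8 + 1 expected by the VAE
    num_inference_steps=40,
    generator=generator,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```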
@@ -223,7 +224,7 @@ def retrieve_latents(

class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
Pipeline for text/image/video-to-video generation.

Reference: https://github.com/Lightricks/LTX-Video

@@ -482,9 +483,6 @@ def check_inputs(
if conditions is not None and (image is not None or video is not None):
raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")

if conditions is None and (image is None and video is None):
raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")

if conditions is None:
if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
raise ValueError(
@@ -642,9 +640,9 @@ def add_noise_to_image_conditioning_latents(

def prepare_latents(
self,
conditions: List[torch.Tensor],
condition_strength: List[float],
condition_frame_index: List[int],
conditions: Optional[List[torch.Tensor]] = None,
condition_strength: Optional[List[float]] = None,
condition_frame_index: Optional[List[int]] = None,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
@@ -654,85 +652,88 @@
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio

shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)

extra_conditioning_latents = []
extra_conditioning_video_ids = []
extra_conditioning_mask = []
extra_conditioning_num_latents = 0
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
condition_latents = self._normalize_latents(
condition_latents, self.vae.latents_mean, self.vae.latents_std
).to(device, dtype=dtype)

num_data_frames = data.size(2)
num_cond_frames = condition_latents.size(2)

if frame_index == 0:
latents[:, :, :num_cond_frames] = torch.lerp(
latents[:, :, :num_cond_frames], condition_latents, strength
)
condition_latent_frames_mask[:, :num_cond_frames] = strength
if len(conditions) > 0:
condition_latent_frames_mask = torch.zeros(
(batch_size, num_latent_frames), device=device, dtype=torch.float32
)

else:
if num_data_frames > 1:
if num_cond_frames < num_prefix_latent_frames:
raise ValueError(
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
)

if num_cond_frames > num_prefix_latent_frames:
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
latents[:, :, start_frame:end_frame] = torch.lerp(
latents[:, :, start_frame:end_frame],
condition_latents[:, :, num_prefix_latent_frames:],
strength,
)
condition_latent_frames_mask[:, start_frame:end_frame] = strength
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]

noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, strength)

condition_video_ids = self._prepare_video_ids(
batch_size,
condition_latents.size(2),
latent_height,
latent_width,
patch_size=self.transformer_spatial_patch_size,
patch_size_t=self.transformer_temporal_patch_size,
device=device,
)
condition_video_ids = self._scale_video_ids(
condition_video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=frame_index,
device=device,
)
condition_latents = self._pack_latents(
condition_latents,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
condition_conditioning_mask = torch.full(
condition_latents.shape[:2], strength, device=device, dtype=dtype
)
extra_conditioning_latents = []
extra_conditioning_video_ids = []
extra_conditioning_mask = []
extra_conditioning_num_latents = 0
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
condition_latents = self._normalize_latents(
condition_latents, self.vae.latents_mean, self.vae.latents_std
).to(device, dtype=dtype)

num_data_frames = data.size(2)
num_cond_frames = condition_latents.size(2)

if frame_index == 0:
latents[:, :, :num_cond_frames] = torch.lerp(
latents[:, :, :num_cond_frames], condition_latents, strength
)
condition_latent_frames_mask[:, :num_cond_frames] = strength

else:
if num_data_frames > 1:
if num_cond_frames < num_prefix_latent_frames:
raise ValueError(
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
)

if num_cond_frames > num_prefix_latent_frames:
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
latents[:, :, start_frame:end_frame] = torch.lerp(
latents[:, :, start_frame:end_frame],
condition_latents[:, :, num_prefix_latent_frames:],
strength,
)
condition_latent_frames_mask[:, start_frame:end_frame] = strength
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]

noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, strength)

condition_video_ids = self._prepare_video_ids(
batch_size,
condition_latents.size(2),
latent_height,
latent_width,
patch_size=self.transformer_spatial_patch_size,
patch_size_t=self.transformer_temporal_patch_size,
device=device,
)
condition_video_ids = self._scale_video_ids(
condition_video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=frame_index,
device=device,
)
condition_latents = self._pack_latents(
condition_latents,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
condition_conditioning_mask = torch.full(
condition_latents.shape[:2], strength, device=device, dtype=dtype
)

extra_conditioning_latents.append(condition_latents)
extra_conditioning_video_ids.append(condition_video_ids)
extra_conditioning_mask.append(condition_conditioning_mask)
extra_conditioning_num_latents += condition_latents.size(1)
extra_conditioning_latents.append(condition_latents)
extra_conditioning_video_ids.append(condition_video_ids)
extra_conditioning_mask.append(condition_conditioning_mask)
extra_conditioning_num_latents += condition_latents.size(1)

video_ids = self._prepare_video_ids(
batch_size,
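A small sketch of the shape arithmetic this `prepare_latents` hunk relies on, assuming LTX-style compression factors of 32 (spatial) and 8 (temporal). With no conditions, the loop above is skipped entirely and the method returns `conditioning_mask=None` together with `extra_conditioning_num_latents=0`:

```python
# Latent grid for a 704x480, 161-frame request (assumed LTX compression ratios).
vae_spatial_compression_ratio = 32
vae_temporal_compression_ratio = 8

height, width, num_frames = 480, 704, 161

num_latent_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 21
latent_height = height // vae_spatial_compression_ratio                     # 15
latent_width = width // vae_spatial_compression_ratio                       # 22

# Text-only path: nothing is blended into the random latents, no per-frame
# conditioning mask is built, and zero extra conditioning tokens are added.
print(num_latent_frames, latent_height, latent_width)  # 21 15 22
```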
@@ -743,7 +744,10 @@ def prepare_latents(
patch_size=self.transformer_spatial_patch_size,
device=device,
)
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
if len(conditions) > 0:
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
else:
conditioning_mask, extra_conditioning_num_latents = None, 0
video_ids = self._scale_video_ids(
video_ids,
scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ def prepare_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)

if len(extra_conditioning_latents) > 0:
if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
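For orientation, a toy illustration of what the concatenation above does when conditions are present: packed conditioning tokens are prepended along the sequence dimension, and their count is remembered so they can be stripped again after denoising. The sizes below are illustrative only:

```python
import torch

batch, channels = 1, 128
video_tokens = torch.randn(batch, 6930, channels)      # 21 * 15 * 22 packed video tokens
extra_cond_tokens = torch.randn(batch, 330, channels)  # e.g. one conditioning frame (1 * 15 * 22)

# Conditioning tokens go first; remember how many were added.
latents = torch.cat([extra_cond_tokens, video_tokens], dim=1)
extra_conditioning_num_latents = extra_cond_tokens.size(1)

# ... denoising runs over all tokens ...

# Afterwards only the video tokens are kept for decoding.
latents = latents[:, extra_conditioning_num_latents:]
print(latents.shape)  # torch.Size([1, 6930, 128])
```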
@@ -955,7 +959,7 @@ def __call__(
frame_index = [condition.frame_index for condition in conditions]
image = [condition.image for condition in conditions]
video = [condition.video for condition in conditions]
else:
elif image is not None or video is not None:
if not isinstance(image, list):
image = [image]
num_conditions = 1
@@ -999,32 +1003,34 @@ def __call__(
vae_dtype = self.vae.dtype

conditioning_tensors = []
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
image, video, frame_index, strength
):
if condition_image is not None:
condition_tensor = (
self.video_processor.preprocess(condition_image, height, width)
.unsqueeze(2)
.to(device, dtype=vae_dtype)
)
elif condition_video is not None:
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
num_frames_input = condition_tensor.size(2)
num_frames_output = self.trim_conditioning_sequence(
condition_frame_index, num_frames_input, num_frames
)
condition_tensor = condition_tensor[:, :, :num_frames_output]
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
else:
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")

if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
raise ValueError(
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
f"but got {condition_tensor.size(2)} frames."
)
conditioning_tensors.append(condition_tensor)
is_conditioning_image_or_video = image is not None or video is not None
if is_conditioning_image_or_video:
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
image, video, frame_index, strength
):
if condition_image is not None:
condition_tensor = (
self.video_processor.preprocess(condition_image, height, width)
.unsqueeze(2)
.to(device, dtype=vae_dtype)
)
elif condition_video is not None:
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
num_frames_input = condition_tensor.size(2)
num_frames_output = self.trim_conditioning_sequence(
condition_frame_index, num_frames_input, num_frames
)
condition_tensor = condition_tensor[:, :, :num_frames_output]
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
else:
raise ValueError("Either `image` or `video` must be provided for conditioning.")

if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
raise ValueError(
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
f"but got {condition_tensor.size(2)} frames."
)
conditioning_tensors.append(condition_tensor)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
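The frame-count check above requires conditioning clips of the form `k * vae_temporal_compression_ratio + 1`. A hypothetical helper (not part of the pipeline) that snaps an arbitrary clip length to a valid value, assuming a temporal ratio of 8:

```python
def snap_to_valid_frame_count(num_frames: int, temporal_ratio: int = 8) -> int:
    """Round down to the nearest count of the form k * temporal_ratio + 1."""
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1.")
    return ((num_frames - 1) // temporal_ratio) * temporal_ratio + 1

# Valid counts for ratio 8 are 1, 9, 17, 25, ...
print(snap_to_valid_frame_count(161))  # 161 (already valid)
print(snap_to_valid_frame_count(100))  # 97
```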
@@ -1045,7 +1051,7 @@ def __call__(
video_coords = video_coords.float()
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)

init_latents = latents.clone()
init_latents = latents.clone() if is_conditioning_image_or_video else None

if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,15 +1071,15 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# 7. Denoising loop
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

self._current_timestep = t

if image_cond_noise_scale > 0:
if image_cond_noise_scale > 0 and init_latents is not None:
# Add timestep-dependent noise to the hard-conditioning latents
# This helps with motion continuity, especially when conditioned on a single frame
latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,16 +1092,18 @@ def __call__(
)

latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
conditioning_mask_model_input = (
torch.cat([conditioning_mask, conditioning_mask])
if self.do_classifier_free_guidance
else conditioning_mask
)
if is_conditioning_image_or_video:
conditioning_mask_model_input = (
torch.cat([conditioning_mask, conditioning_mask])
if self.do_classifier_free_guidance
else conditioning_mask
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

noise_pred = self.transformer(
hidden_states=latent_model_input,
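A toy illustration of the per-token timestep clamp a few lines up: tokens with a high conditioning strength get an effective timestep near zero, so the model treats them as nearly clean and leaves them essentially untouched. Values are illustrative:

```python
import torch

# Current denoising timestep, broadcast over four latent tokens.
timestep = torch.full((1, 4), 800.0)

# Per-token conditioning strengths: 1.0 = hard-conditioned, 0.0 = free.
conditioning_mask = torch.tensor([[1.0, 0.75, 0.0, 0.0]])

timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)
print(timestep)  # tensor([[  0., 250., 800., 800.]])
```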
@@ -1115,8 +1123,11 @@ def __call__(
denoised_latents = self.scheduler.step(
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
)[0]
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
if is_conditioning_image_or_video:
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
else:
latents = denoised_latents

if callback_on_step_end is not None:
callback_kwargs = {}
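And a matching sketch of the `tokens_to_denoise_mask` selection above: a token is overwritten with the scheduler output only while the current noise level is below `1 - strength`, whereas the new text-only branch simply takes `denoised_latents` wholesale. Again, toy values:

```python
import torch

t = torch.tensor(800.0)                                    # current timestep (of 1000)
conditioning_mask = torch.tensor([[1.0, 0.75, 0.0, 0.0]])  # per-token strengths

latents = torch.zeros(1, 4, 2)           # current latents (toy)
denoised_latents = torch.ones(1, 4, 2)   # scheduler output (toy)

tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
print(latents[..., 0])  # tensor([[0., 0., 1., 1.]])
```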
@@ -1134,7 +1145,9 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()

latents = latents[:, extra_conditioning_num_latents:]
if is_conditioning_image_or_video:
latents = latents[:, extra_conditioning_num_latents:]

latents = self._unpack_latents(
latents,
latent_num_frames,