
Commit ed2f7e3

add docs tests + more refactor
1 parent 00e9670 commit ed2f7e3

9 files changed (+409 −39 lines changed)


docs/source/en/api/pipelines/ltx_video.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
   - all
   - __call__
 
+## LTXConditionPipeline
+
+[[autodoc]] LTXConditionPipeline
+  - all
+  - __call__
+
 ## LTXPipelineOutput
 
 [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
```
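For reference, a minimal usage sketch of the pipeline documented above, following the example docstring added in `pipeline_ltx_condition.py` further down. The prompt, `num_frames`, and the `LTXVideoCondition` import path are illustrative assumptions, not taken from this commit:

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition  # assumed import path
from diffusers.utils import export_to_video, load_image

pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
# Pin the image to the first output frame at full conditioning strength.
condition = LTXVideoCondition(image=image, frame_index=0, strength=1.0)

video = pipe(
    conditions=[condition],
    prompt="A cute little penguin takes out a book and starts reading it",  # illustrative
    num_frames=161,  # illustrative
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```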

scripts/convert_ltx_to_diffusers.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -105,6 +105,7 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
     "model.diffusion_model": remove_keys_,
+    "decoder.timestep_scale_multiplier": remove_keys_,
 }
 
 
@@ -270,6 +271,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
         "decoder_causal": False,
         "spatial_compression_ratio": 32,
         "temporal_compression_ratio": 8,
+        "timestep_scale_multiplier": 1000.0,
     }
     VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
     return config
```
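The first hunk matters because the original checkpoint serializes the scale as a decoder weight, while the diffusers model now stores it as a plain config value (second hunk), so the weight entry has to be dropped during conversion. A minimal sketch of the remapping pattern, with a simplified `remove_keys_` and a hypothetical table name (the script's real helper is not shown in this diff):

```python
from typing import Any, Dict


def remove_keys_(key: str, state_dict: Dict[str, Any]) -> None:
    # Drop every state-dict entry whose name contains `key` (a sketch of the
    # helper the script's special-keys table points at; the real
    # implementation may differ in detail).
    for k in list(state_dict.keys()):
        if key in k:
            state_dict.pop(k)


SPECIAL_KEYS_REMAP = {  # hypothetical subset of the script's table
    "model.diffusion_model": remove_keys_,
    "decoder.timestep_scale_multiplier": remove_keys_,
}

state_dict = {
    "decoder.timestep_scale_multiplier": 1000.0,
    "decoder.conv_in.weight": "...",
}
for special_key, handler in SPECIAL_KEYS_REMAP.items():
    handler(special_key, state_dict)

print(sorted(state_dict))  # ['decoder.conv_in.weight']
```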

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -347,6 +347,7 @@
         "LDMTextToImagePipeline",
         "LEditsPPPipelineStableDiffusion",
         "LEditsPPPipelineStableDiffusionXL",
+        "LTXConditionPipeline",
         "LTXImageToVideoPipeline",
         "LTXPipeline",
         "Lumina2Text2ImgPipeline",
@@ -857,6 +858,7 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LTXConditionPipeline,
         LTXImageToVideoPipeline,
         LTXPipeline,
         Lumina2Text2ImgPipeline,
```
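Both hunks are required because `diffusers` builds its root namespace lazily: the string entry registers the name with the lazy module machinery, and the mirrored import in the type-checking branch keeps static analyzers aware of it. After this commit the class resolves from the package root:

```python
# Lazy top-level export (resolved on first attribute access):
from diffusers import LTXConditionPipeline

# Equivalent eager import of the defining module:
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline  # noqa: F811
```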

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -921,12 +921,14 @@ def __init__(
         timestep_conditioning: bool = False,
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
+        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()
 
         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.out_channels = out_channels * patch_size**2
+        self.timestep_scale_multiplier = timestep_scale_multiplier
 
         block_out_channels = tuple(reversed(block_out_channels))
         spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
@@ -981,9 +983,7 @@ def __init__(
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
-        self.timestep_scale_multiplier = None
         if timestep_conditioning:
-            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
 
@@ -992,7 +992,7 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)
 
-        if self.timestep_scale_multiplier is not None:
+        if temb is not None:
             temb = temb * self.timestep_scale_multiplier
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -1107,6 +1107,7 @@ def __init__(
         decoder_causal: bool = False,
         spatial_compression_ratio: int = None,
         temporal_compression_ratio: int = None,
+        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()
 
@@ -1137,6 +1138,7 @@ def __init__(
             inject_noise=decoder_inject_noise,
             upsample_residual=upsample_residual,
             upsample_factor=upsample_factor,
+            timestep_scale_multiplier=timestep_scale_multiplier,
         )
 
         latents_mean = torch.zeros((latent_channels,), requires_grad=False)
```
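The net effect of these hunks: the decoder's timestep scale moves from an `nn.Parameter` that existed only when `timestep_conditioning=True` into a plain config float (defaulting to 1.0, set to 1000.0 by the conversion script above), and the scaling is now gated on whether a timestep embedding was actually passed. A toy sketch of the new control flow, not the real decoder:

```python
from typing import Optional

import torch
from torch import nn


class DecoderSketch(nn.Module):
    """Toy stand-in illustrating the gating change only."""

    def __init__(self, timestep_scale_multiplier: float = 1.0) -> None:
        super().__init__()
        # A plain float taken from the model config, no longer an nn.Parameter
        # created only under timestep_conditioning=True.
        self.timestep_scale_multiplier = timestep_scale_multiplier

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The old gate (`self.timestep_scale_multiplier is not None`) would hit
        # `None * multiplier` when temb was omitted; the new gate cannot.
        if temb is not None:
            temb = temb * self.timestep_scale_multiplier
        return hidden_states  # the real decoder feeds temb to its blocks
```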

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -260,7 +260,7 @@
         ]
     )
     _import_structure["latte"] = ["LattePipeline"]
-    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
     _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
     _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
     _import_structure["marigold"].extend(
@@ -610,7 +610,7 @@
             LEditsPPPipelineStableDiffusion,
             LEditsPPPipelineStableDiffusionXL,
         )
-        from .ltx import LTXImageToVideoPipeline, LTXPipeline
+        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
         from .lumina import LuminaText2ImgPipeline
         from .lumina2 import Lumina2Text2ImgPipeline
         from .marigold import (
```

src/diffusers/pipelines/ltx/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_ltx"] = ["LTXPipeline"]
+    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
     _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_ltx import LTXPipeline
+        from .pipeline_ltx_condition import LTXConditionPipeline
         from .pipeline_ltx_image2video import LTXImageToVideoPipeline
 
 else:
```

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 97 additions & 25 deletions
````diff
@@ -21,6 +21,7 @@
 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKLLTXVideo
 from ...models.transformers import LTXVideoTransformer3DModel
@@ -45,12 +46,11 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import LTXImageToVideoPipeline
+        >>> from diffusers import LTXConditionPipeline
         >>> from diffusers.utils import export_to_video, load_image
 
-        >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+        >>> pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
-
         >>> image = load_image(
         ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
         ... )
@@ -405,6 +405,11 @@ def encode_prompt(
     def check_inputs(
         self,
         prompt,
+        conditions,
+        image,
+        video,
+        frame_index,
+        strength,
         height,
         width,
         callback_on_step_end_tensor_inputs=None,
@@ -455,6 +460,26 @@ def check_inputs(
                 f" {negative_prompt_attention_mask.shape}."
             )
 
+        if conditions is not None and (image is not None or video is not None):
+            raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
+
+        if conditions is None and (image is None and video is None):
+            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
+
+        if conditions is None:
+            if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
+                raise ValueError(
+                    "If `conditions` is not provided, `image` and `frame_index` must be of the same length."
+                )
+            elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength):
+                raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.")
+            elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index):
+                raise ValueError(
+                    "If `conditions` is not provided, `video` and `frame_index` must be of the same length."
+                )
+            elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
+                raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
+
     @staticmethod
     def _prepare_video_ids(
         batch_size: int,
````
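The new length checks only fire when both sides are lists, since scalar `frame_index` and `strength` are broadcast later in `__call__`. A standalone sketch of the rule (hypothetical helper, simplified types):

```python
# Sketch of the new length rule: list-valued `image`/`video` must pair
# one-to-one with list-valued `frame_index`/`strength`.
def validate_lengths(items, frame_index, strength, name="image"):
    if isinstance(items, list) and isinstance(frame_index, list) and len(items) != len(frame_index):
        raise ValueError(f"`{name}` and `frame_index` must be of the same length.")
    if isinstance(items, list) and isinstance(strength, list) and len(items) != len(strength):
        raise ValueError(f"`{name}` and `strength` must be of the same length.")


validate_lengths(["img0", "img1"], [0, 60], [1.0, 0.8])    # passes
try:
    validate_lengths(["img0", "img1"], [0, 30, 60], [1.0])  # raises
except ValueError as e:
    print(e)
```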
```diff
@@ -699,7 +724,8 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        video_ids_scaled = self._scale_video_ids(
+        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
             scale_factor_t=self.vae_temporal_compression_ratio,
@@ -709,11 +735,10 @@ def prepare_latents(
         latents = self._pack_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
 
         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
-            video_ids = torch.cat([*extra_conditioning_video_ids, video_ids_scaled], dim=2)
+            video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
 
         return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
```
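This refactor moves the `gather` ahead of `_scale_video_ids` and drops the `video_ids_scaled` temporary: the per-token conditioning mask must be indexed with the unscaled latent-frame row of `video_ids`, before those ids are rescaled to pixel-space coordinates and the name is reused. A toy illustration of that `gather` (shapes simplified):

```python
import torch

# condition_latent_frames_mask: (batch, num_latent_frames) strength per frame.
# video_ids[:, 0] holds, for every packed latent token, the index of the
# latent frame it came from, so gather(1, ...) broadcasts per-frame strengths
# down to per-token values.
mask = torch.tensor([[1.0, 0.5, 0.0]])          # 3 latent frames
token_frame_ids = torch.tensor([[0, 0, 1, 2]])  # 4 tokens -> frame of origin
print(mask.gather(1, token_frame_ids))          # tensor([[1.0, 1.0, 0.5, 0.0]])
```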
```diff
@@ -742,7 +767,11 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
+        conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None,
+        image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+        video: List[PipelineImageInput] = None,
+        frame_index: Union[int, List[int]] = 0,
+        strength: Union[float, List[float]] = 1.0,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
@@ -773,8 +802,19 @@ def __call__(
         Function invoked when calling the pipeline for generation.
 
         Args:
-            conditions (`List[LTXVideoCondition]`):
-                The list of frame-conditioning items for the video generation.
+            conditions (`List[LTXVideoCondition]`, *optional*):
+                The list of frame-conditioning items for the video generation. If not provided, conditions will be
+                created using `image`, `video`, `frame_index` and `strength`.
+            image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The image or images to condition the video generation. If not provided, one has to pass `video` or
+                `conditions`.
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to condition the video generation. If not provided, one has to pass `image` or `conditions`.
+            frame_index (`int` or `List[int]`, *optional*):
+                The frame index or frame indices at which the image or video will conditionally affect the video
+                generation. If not provided, one has to pass `conditions`.
+            strength (`float` or `List[float]`, *optional*):
+                The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
```
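As documented above, the raw-input path mirrors the `conditions` path. A hedged sketch of a two-keyframe call; the second image URL, the prompt, and `num_frames` are placeholders, not taken from this commit:

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video, load_image

pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16).to("cuda")

first = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
last = load_image("https://example.com/last_frame.png")  # placeholder URL

# Two keyframes pinned at both ends of the clip; list arguments must have the
# same length (enforced by the new check_inputs), scalars would be broadcast.
video = pipe(
    image=[first, last],
    frame_index=[0, 160],
    strength=[1.0, 1.0],
    prompt="A short clip interpolating between two keyframes",  # illustrative
    num_frames=161,  # illustrative
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```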
```diff
@@ -857,6 +897,11 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
+            conditions=conditions,
+            image=image,
+            video=video,
+            frame_index=frame_index,
+            strength=strength,
             height=height,
             width=width,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -878,6 +923,31 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        if conditions is not None:
+            if not isinstance(conditions, list):
+                conditions = [conditions]
+
+            strength = [condition.strength for condition in conditions]
+            frame_index = [condition.frame_index for condition in conditions]
+            image = [condition.image for condition in conditions]
+            video = [condition.video for condition in conditions]
+        else:
+            if not isinstance(image, list):
+                image = [image]
+                num_conditions = 1
+            elif isinstance(image, list):
+                num_conditions = len(image)
+            if not isinstance(video, list):
+                video = [video]
+                num_conditions = 1
+            elif isinstance(video, list):
+                num_conditions = len(video)
+
+            if not isinstance(frame_index, list):
+                frame_index = [frame_index] * num_conditions
+            if not isinstance(strength, list):
+                strength = [strength] * num_conditions
+
         device = self._execution_device
 
         # 3. Prepare text embeddings
```
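After this normalization the rest of `__call__` only ever sees four parallel lists, whichever entry point the caller used. A toy stand-in showing the reduction (field names match the pipeline's `LTXVideoCondition`; the class itself is mocked here):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Condition:
    """Toy stand-in for LTXVideoCondition with the same field names."""

    image: Optional[str] = None
    video: Optional[str] = None
    frame_index: int = 0
    strength: float = 1.0


conditions = [Condition(image="img0"), Condition(image="img1", frame_index=60, strength=0.8)]

# The structured path reduces to the same four parallel lists the raw path uses:
strength = [c.strength for c in conditions]        # [1.0, 0.8]
frame_index = [c.frame_index for c in conditions]  # [0, 60]
image = [c.image for c in conditions]              # ['img0', 'img1']
video = [c.video for c in conditions]              # [None, None]
```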
```diff
@@ -905,17 +975,20 @@ def __call__(
         vae_dtype = self.vae.dtype
 
         conditioning_tensors = []
-        conditioning_strengths = []
-        conditioning_start_frames = []
-
-        for condition in conditions:
-            if condition.image is not None:
-                condition_tensor = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
-            elif condition.video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition.video, height, width)
+        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+            image, video, frame_index, strength
+        ):
+            if condition_image is not None:
+                condition_tensor = (
+                    self.video_processor.preprocess(condition_image, height, width)
+                    .unsqueeze(2)
+                    .to(device, dtype=vae_dtype)
+                )
+            elif condition_video is not None:
+                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
             num_frames_input = condition_tensor.size(2)
             num_frames_output = self.trim_conditioning_sequence(
-                condition.frame_index, num_frames_input, num_frames
+                condition_frame_index, num_frames_input, num_frames
             )
             condition_tensor = condition_tensor[:, :, :num_frames_output]
             condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
@@ -928,15 +1001,13 @@ def __call__(
                     f"but got {condition_tensor.size(2)} frames."
                 )
             conditioning_tensors.append(condition_tensor)
-            conditioning_strengths.append(condition.strength)
-            conditioning_start_frames.append(condition.frame_index)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
             conditioning_tensors,
-            conditioning_strengths,
-            conditioning_start_frames,
+            strength,
+            frame_index,
             batch_size=batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents,
             height=height,
@@ -1015,9 +1086,10 @@ def __call__(
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                     timestep, _ = timestep.chunk(2)
 
-                denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0]
-                t_eps = 1e-6
-                tokens_to_denoise_mask = (t / 1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
+                denoised_latents = self.scheduler.step(
+                    -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
+                )[0]
+                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
                 latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
 
                 if callback_on_step_end is not None:
```
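Two things change in the last hunk: the scheduler is now stepped with the scalar `t` while the (CFG-chunked) per-token `timestep` tensor is forwarded as `per_token_timesteps`, presumably supported by a scheduler change in one of the files this diff does not show, and the `t_eps` temporary is inlined. The surviving mask decides which packed tokens accept the denoised value; a toy illustration:

```python
import torch

# Toy illustration of the per-token freezing rule applied after each scheduler
# step: a token with conditioning strength s is only overwritten with the
# denoised value while t / 1000 - eps < 1 - s, so strength-1.0 conditioning
# tokens stay fixed for the entire sampling loop.
conditioning_mask = torch.tensor([1.0, 0.8, 0.0])  # per-token strengths
for t in (1000.0, 900.0, 100.0):
    tokens_to_denoise = t / 1000 - 1e-6 < (1.0 - conditioning_mask)
    print(int(t), tokens_to_denoise.tolist())
# 1000 [False, False, True]
# 900  [False, False, True]
# 100  [False, True, True]
```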
