
Commit 1f008fc

image2video
1 parent e10b7e7 commit 1f008fc

9 files changed: +504 additions, −57 deletions

docs/source/en/api/pipelines/ltx.md
Lines changed: 6 additions & 0 deletions

@@ -28,6 +28,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
   - all
   - __call__
 
+## LTXImageToVideoPipeline
+
+[[autodoc]] LTXImageToVideoPipeline
+  - all
+  - __call__
+
 ## LTXPipelineOutput
 
 [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
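With the docs entry in place, the new pipeline is used much like the existing LTXPipeline. A minimal usage sketch; the checkpoint id, generation settings, and exact call signature here are assumptions modeled on the text-to-video pipeline, not something this commit shows:

import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Checkpoint id and generation settings below are illustrative assumptions.
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("input_frame.png")  # first frame to animate
video = pipe(
    image=image,
    prompt="a calm forest clearing at dawn, gentle camera pan",
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)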

src/diffusers/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -317,6 +317,7 @@
     "LDMTextToImagePipeline",
     "LEditsPPPipelineStableDiffusion",
     "LEditsPPPipelineStableDiffusionXL",
+    "LTXImageToVideoPipeline",
     "LTXPipeline",
     "LuminaText2ImgPipeline",
     "MarigoldDepthPipeline",
@@ -790,6 +791,7 @@
     LDMTextToImagePipeline,
     LEditsPPPipelineStableDiffusion,
     LEditsPPPipelineStableDiffusionXL,
+    LTXImageToVideoPipeline,
     LTXPipeline,
     LuminaText2ImgPipeline,
     MarigoldDepthPipeline,

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
Lines changed: 1 addition & 1 deletion

@@ -802,7 +802,7 @@ def __init__(
         )
 
         latents_mean = torch.zeros((latent_channels,), requires_grad=False)
-        latents_std = torch.zeros((latent_channels,), requires_grad=False)
+        latents_std = torch.ones((latent_channels,), requires_grad=False)
         self.register_buffer("latents_mean", latents_mean, persistent=True)
         self.register_buffer("latents_std", latents_std, persistent=True)
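This one-line fix matters because the pipeline divides by latents_std when denormalizing, so a buffer initialized to zeros produces infs as soon as the placeholder statistics are used; ones act as a safe identity default. A quick standalone illustration (the channel count is a placeholder, not the model's real config value):

import torch

latent_channels = 128  # placeholder value for illustration only
latents = torch.randn(1, latent_channels, 2, 4, 4)
mean = torch.zeros(latent_channels).view(1, -1, 1, 1, 1)

bad_std = torch.zeros(latent_channels).view(1, -1, 1, 1, 1)  # old default: division by zero
good_std = torch.ones(latent_channels).view(1, -1, 1, 1, 1)  # new default: identity scaling

print(((latents - mean) / bad_std).isinf().any())      # tensor(True)
print(((latents - mean) / good_std).isfinite().all())  # tensor(True)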

src/diffusers/models/transformers/transformer_ltx.py
Lines changed: 6 additions & 1 deletion

@@ -116,7 +116,12 @@ def __init__(
         self.theta = theta
 
     def forward(
-        self, hidden_states: torch.Tensor, num_frames: int, height: int, width: int, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None
+        self,
+        hidden_states: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = hidden_states.size(0)

src/diffusers/pipelines/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -245,7 +245,7 @@
         ]
     )
     _import_structure["latte"] = ["LattePipeline"]
-    _import_structure["ltx"] = ["LTXPipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
     _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
     _import_structure["marigold"].extend(
         [
@@ -578,7 +578,7 @@
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
     )
-    from .ltx import LTXPipeline
+    from .ltx import LTXImageToVideoPipeline, LTXPipeline
     from .lumina import LuminaText2ImgPipeline
     from .marigold import (
         MarigoldDepthPipeline,

src/diffusers/pipelines/ltx/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_ltx"] = ["LTXPipeline"]
+    _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -33,6 +34,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_ltx import LTXPipeline
+        from .pipeline_ltx_image2video import LTXImageToVideoPipeline
 
 else:
     import sys
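Both registrations feed diffusers' lazy-import machinery: _import_structure maps each submodule to the names it exports, and a symbol is only actually imported the first time someone accesses it. A rough sketch of that idea using a module-level __getattr__ (PEP 562), not the actual _LazyModule implementation:

import importlib

_import_structure = {
    "pipeline_ltx": ["LTXPipeline"],
    "pipeline_ltx_image2video": ["LTXImageToVideoPipeline"],
}

def __getattr__(name):
    # Import the submodule that registered `name` only on first access.
    for module_name, symbols in _import_structure.items():
        if name in symbols:
            module = importlib.import_module(f".{module_name}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")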

src/diffusers/pipelines/ltx/pipeline_ltx.py
Lines changed: 30 additions & 8 deletions

@@ -415,6 +415,24 @@ def _unpack_latents
         latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
         return latents
 
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
     def prepare_latents(
         self,
         batch_size: int = 1,
@@ -443,7 +461,9 @@ def prepare_latents(
         )
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
+        latents = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        )
         return latents
 
     @property
@@ -709,15 +729,17 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            latents = self._unpack_latents(latents, latent_num_frames, latent_height, latent_width, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
-            # unscale/denormalize the latents
-            latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
-                latents.device, latents.dtype
+            latents = self._unpack_latents(
+                latents,
+                latent_num_frames,
+                latent_height,
+                latent_width,
+                self.transformer_spatial_patch_size,
+                self.transformer_temporal_patch_size,
             )
-            latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
-                latents.device, latents.dtype
+            latents = self._denormalize_latents(
+                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
             )
-            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
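Factoring the arithmetic into _normalize_latents and _denormalize_latents makes the pair's inverse relationship easy to verify in isolation, which is why the __call__ body shrinks to a single helper call. A small self-contained sanity check; shapes, channel count, and the scaling factor are illustrative:

import torch

def normalize(latents, mean, std, scaling_factor=1.0):
    mean = mean.view(1, -1, 1, 1, 1)
    std = std.view(1, -1, 1, 1, 1)
    return (latents - mean) * scaling_factor / std

def denormalize(latents, mean, std, scaling_factor=1.0):
    mean = mean.view(1, -1, 1, 1, 1)
    std = std.view(1, -1, 1, 1, 1)
    return latents * std / scaling_factor + mean

channels = 128  # illustrative
x = torch.randn(2, channels, 3, 4, 4)
mean = torch.randn(channels)
std = torch.rand(channels) + 0.5  # strictly positive, see the VAE fix above
assert torch.allclose(denormalize(normalize(x, mean, std, 0.7), mean, std, 0.7), x, atol=1e-5)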