Skip to content

Commit 9d776e7

Browse files
committed
update
1 parent e847d85 commit 9d776e7

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,10 @@ def __call__(
565565
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
566566
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
567567
Pre-generated attention mask for negative text embeddings.
568+
decode_timestep (`float`, defaults to `0.05`):
569+
The timestep at which generated video is decoded.
570+
decode_noise_scale (`float`, defaults to `0.025`):
571+
The interpolation factor between random noise and denoised latents at the decode timestep.
568572
output_type (`str`, *optional*, defaults to `"pil"`):
569573
The output format of the generate image. Choose between
570574
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,8 @@ def __call__(
571571
prompt_attention_mask: Optional[torch.Tensor] = None,
572572
negative_prompt_embeds: Optional[torch.Tensor] = None,
573573
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
574+
decode_timestep: Union[float, List[float]] = 0.05,
575+
decode_noise_scale: Union[float, List[float]] = 0.025,
574576
output_type: Optional[str] = "pil",
575577
return_dict: bool = True,
576578
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -625,6 +627,10 @@ def __call__(
625627
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
626628
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
627629
Pre-generated attention mask for negative text embeddings.
630+
decode_timestep (`float`, defaults to `0.05`):
631+
The timestep at which generated video is decoded.
632+
decode_noise_scale (`float`, defaults to `0.025`):
633+
The interpolation factor between random noise and denoised latents at the decode timestep.
628634
output_type (`str`, *optional*, defaults to `"pil"`):
629635
The output format of the generate image. Choose between
630636
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -849,6 +855,24 @@ def __call__(
849855
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
850856
)
851857
latents = latents.to(prompt_embeds.dtype)
858+
859+
if not self.vae.config.timestep_conditioning:
860+
timestep = None
861+
else:
862+
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
863+
if not isinstance(decode_timestep, list):
864+
decode_timestep = [decode_timestep] * batch_size
865+
if decode_noise_scale is None:
866+
decode_noise_scale = decode_timestep
867+
elif not isinstance(decode_noise_scale, list):
868+
decode_noise_scale = [decode_noise_scale] * batch_size
869+
870+
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
871+
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
872+
:, None, None, None, None
873+
]
874+
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
875+
852876
video = self.vae.decode(latents, return_dict=False)[0]
853877
video = self.video_processor.postprocess_video(video, output_type=output_type)
854878

tests/pipelines/ltx/test_ltx_image2video.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,19 @@ def get_dummy_components(self):
6868

6969
torch.manual_seed(0)
7070
vae = AutoencoderKLLTXVideo(
71+
in_channels=3,
72+
out_channels=3,
7173
latent_channels=8,
7274
block_out_channels=(8, 8, 8, 8),
73-
spatio_temporal_scaling=(True, True, False, False),
75+
decoder_block_out_channels=(8, 8, 8, 8),
7476
layers_per_block=(1, 1, 1, 1, 1),
77+
decoder_layers_per_block=(1, 1, 1, 1, 1),
78+
spatio_temporal_scaling=(True, True, False, False),
79+
decoder_spatio_temporal_scaling=(True, True, False, False),
80+
decoder_inject_noise=(False, False, False, False, False),
81+
upsample_residual=(False, False, False, False),
82+
upsample_factor=(1, 1, 1, 1),
83+
timestep_conditioning=False,
7584
patch_size=1,
7685
patch_size_t=1,
7786
encoder_causal=True,

0 commit comments

Comments
 (0)