
Commit 817b360

support denoising strength for upscaling & video-to-video

1 parent 887613c

2 files changed: +55 −23 lines

src/diffusers/pipelines/ltx/modeling_latent_upsampler.py (2 additions, 1 deletion)

@@ -16,7 +16,7 @@
 
 import torch
 
-from ...configuration_utils import ConfigMixin
+from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.modeling_utils import ModelMixin
 
 
@@ -94,6 +94,7 @@ class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
             Whether to temporally upsample the latent
     """
 
+    @register_to_config
     def __init__(
         self,
         in_channels: int = 128,
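
Note: in diffusers, the `@register_to_config` decorator records the decorated `__init__`'s keyword arguments into the model's config, which is what lets `save_pretrained`/`from_pretrained` round-trip the upsampler; without it the config would be empty and the model could not be reconstructed from a checkpoint. A minimal sketch of the pattern (the `ToyUpsampler` class and its layers are illustrative, not from this commit):

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyUpsampler(ModelMixin, ConfigMixin):
    @register_to_config  # records in_channels/mid_channels into self.config
    def __init__(self, in_channels: int = 128, mid_channels: int = 512):
        super().__init__()
        self.proj_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.proj_out = nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj_out(self.proj_in(hidden_states))


model = ToyUpsampler(in_channels=64)
print(model.config.in_channels)  # 64 -- available because of the decorator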

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py (53 additions, 22 deletions)

@@ -430,6 +430,7 @@ def check_inputs(
         video,
         frame_index,
         strength,
+        denoise_strength,
         height,
         width,
         callback_on_step_end_tensor_inputs=None,
@@ -497,6 +498,9 @@ def check_inputs(
         elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
             raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
 
+        if denoise_strength < 0 or denoise_strength > 1:
+            raise ValueError(f"The value of `denoise_strength` should be in [0.0, 1.0] but is {denoise_strength}")
+
     @staticmethod
     def _prepare_video_ids(
         batch_size: int,
@@ -649,6 +653,8 @@ def prepare_latents(
         width: int = 704,
         num_frames: int = 161,
         num_prefix_latent_frames: int = 2,
+        sigma: Optional[torch.Tensor] = None,
+        latents: Optional[torch.Tensor] = None,
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
@@ -658,7 +664,18 @@ def prepare_latents(
         latent_width = width // self.vae_spatial_compression_ratio
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        if latents is not None and sigma is not None:
+            if latents.shape != shape:
+                raise ValueError(
+                    f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
+                )
+            latents = latents.to(device=device, dtype=dtype)
+            sigma = sigma.to(device=device, dtype=dtype)
+            latents = sigma * noise + (1 - sigma) * latents
+        else:
+            latents = noise
 
         if len(conditions) > 0:
            condition_latent_frames_mask = torch.zeros(
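
Note: the new branch seeds video-to-video runs SDEdit-style. Instead of starting from pure noise, the caller's latents are partially re-noised by linear interpolation toward Gaussian noise at the starting sigma, following the flow-matching convention where sigma = 1 is pure noise and sigma = 0 is the clean sample. A standalone sketch of just that mixing step (the tensor shapes and sigma value are illustrative):

import torch

# toy latent shape: (batch, channels, frames, height, width)
shape = (1, 128, 3, 16, 16)
clean_latents = torch.randn(shape)  # e.g. encoded or upscaled video latents
noise = torch.randn(shape)          # fresh Gaussian noise
sigma = torch.tensor(0.4)           # first sigma of the truncated schedule

# partial noising: sigma=1.0 -> pure noise, sigma=0.0 -> input unchanged
noised = sigma * noise + (1 - sigma) * clean_latents

# for sigma < 0.5 the result stays closer to the input than to the noise
print((noised - clean_latents).abs().mean() < (noised - noise).abs().mean())  # True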
@@ -766,6 +783,13 @@ def prepare_latents(
 
         return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
 
+    def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
+        num_steps = min(int(num_inference_steps * strength), num_inference_steps)
+        start_index = max(num_inference_steps - num_steps, 0)
+        sigmas = sigmas[start_index:]
+        timesteps = timesteps[start_index:]
+        return sigmas, timesteps, num_inference_steps - start_index
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
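
Note: `get_timesteps` mirrors the strength handling in other img2img-style diffusers pipelines: it keeps only the tail of the schedule, so `denoise_strength=0.6` with 50 steps skips the first 20 (high-noise) steps and denoises for the remaining 30. A quick standalone check of that arithmetic (the linear toy schedule stands in for `linear_quadratic_schedule`):

import torch

def get_timesteps(sigmas, timesteps, num_inference_steps, strength):
    # keep only the last `strength` fraction of the schedule
    num_steps = min(int(num_inference_steps * strength), num_inference_steps)
    start_index = max(num_inference_steps - num_steps, 0)
    return sigmas[start_index:], timesteps[start_index:], num_inference_steps - start_index

num_inference_steps = 50
sigmas = torch.linspace(1.0, 0.0, num_inference_steps)  # toy stand-in schedule
timesteps = sigmas * 1000

sigmas, timesteps, num_inference_steps = get_timesteps(sigmas, timesteps, 50, strength=0.6)
print(num_inference_steps)  # 30
print(float(sigmas[0]))     # ~0.59 -- the sigma used to partially noise the input latents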
@@ -799,6 +823,7 @@ def __call__(
         video: List[PipelineImageInput] = None,
         frame_index: Union[int, List[int]] = 0,
         strength: Union[float, List[float]] = 1.0,
+        denoise_strength: float = 1.0,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
@@ -842,6 +867,10 @@ def __call__(
                 generation. If not provided, one has to pass `conditions`.
             strength (`float` or `List[float]`, *optional*):
                 The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
+            denoise_strength (`float`, defaults to `1.0`):
+                The strength of the noise added to the latents for editing. Higher strength leads to more noise added
+                to the latents, therefore leading to more differences between the original video and the generated
+                video. This is useful for video-to-video editing.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
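
Taken together with the removed `latents` guard below, this enables a two-pass upscale-and-refine flow: generate (or encode) latents, upsample them, then call the pipeline again with `denoise_strength < 1` so the upscaled latents are only partially re-noised. A hedged usage sketch (the checkpoint name, resolutions, and parameter values are illustrative of the intended flow, not taken from this diff):

import torch
from diffusers import LTXConditionPipeline

# checkpoint name is illustrative
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "a cat looking out of a rainy window"

# First pass: generate low-resolution latents instead of decoded frames.
latents = pipe(
    prompt=prompt,
    width=512,
    height=320,
    num_frames=97,
    num_inference_steps=30,
    output_type="latent",
).frames

# (Upscale `latents` here with the latent upsampler touched in the other file;
# the result must match the latent shape expected at the new resolution.)

# Second pass: partially re-noise the latents and denoise only the tail of the
# schedule -- denoise_strength=0.4 keeps roughly the last 40% of the steps.
video = pipe(
    prompt=prompt,
    width=1024,
    height=640,
    num_frames=97,
    num_inference_steps=30,
    latents=latents,
    denoise_strength=0.4,
    output_type="np",
).frames[0]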
@@ -918,8 +947,6 @@ def __call__(
 
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-        if latents is not None:
-            raise ValueError("Passing latents is not yet supported.")
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -929,6 +956,7 @@ def __call__(
             video=video,
             frame_index=frame_index,
             strength=strength,
+            denoise_strength=denoise_strength,
             height=height,
             width=width,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -977,8 +1005,9 @@ def __call__(
             strength = [strength] * num_conditions
 
         device = self._execution_device
+        vae_dtype = self.vae.dtype
 
-        # 3. Prepare text embeddings
+        # 3. Prepare text embeddings & conditioning image/video
         (
             prompt_embeds,
             prompt_attention_mask,
@@ -1000,8 +1029,6 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
-        vae_dtype = self.vae.dtype
-
         conditioning_tensors = []
         is_conditioning_image_or_video = image is not None or video is not None
         if is_conditioning_image_or_video:
@@ -1032,7 +1059,24 @@ def __call__(
             )
             conditioning_tensors.append(condition_tensor)
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+        latent_height = height // self.vae_spatial_compression_ratio
+        latent_width = width // self.vae_spatial_compression_ratio
+        sigmas = linear_quadratic_schedule(num_inference_steps)
+        timesteps = sigmas * 1000
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        latent_sigma = None
+        if denoise_strength < 1:
+            sigmas, timesteps, num_inference_steps = self.get_timesteps(
+                sigmas, timesteps, num_inference_steps, denoise_strength
+            )
+            latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
             conditioning_tensors,
@@ -1043,6 +1087,8 @@ def __call__(
             height=height,
             width=width,
             num_frames=num_frames,
+            sigma=latent_sigma,
+            latents=latents,
             generator=generator,
             device=device,
             dtype=torch.float32,
@@ -1056,21 +1102,6 @@ def __call__(
         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
 
-        # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        sigmas = linear_quadratic_schedule(num_inference_steps)
-        timesteps = sigmas * 1000
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            timesteps=timesteps,
-        )
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
