
Commit 2bf5b72

fixes

1 parent b7434cd commit 2bf5b72

File tree

1 file changed (+24, -38 lines)


src/diffusers/pipelines/ltx/pipeline_ltx_video2video.py

Lines changed: 24 additions & 38 deletions
@@ -456,9 +456,8 @@ def prepare_latents(
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
-        sigma: torch.Tensor = 1.0,
+        timestep: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # TODO: do we need the `conditioning_mask` here? I think `conditioning_mask` should be all ones.
         height = height // self.vae_spatial_compression_ratio
         width = width // self.vae_spatial_compression_ratio

@@ -471,14 +470,9 @@ def prepare_latents(
             (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
         )
         shape = (batch_size, num_channels_latents, num_frames, height, width)
-        mask_shape = (batch_size, 1, num_frames, height, width)

         if latents is not None:
-            conditioning_mask = latents.new_ones(shape)
-            conditioning_mask = self._pack_latents(
-                conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-            )
-            return latents.to(device=device, dtype=dtype), conditioning_mask
+            return latents.to(device=device, dtype=dtype)

         if isinstance(generator, list):
             if len(generator) != batch_size:
@@ -491,30 +485,21 @@ def prepare_latents(
                 retrieve_latents(self.vae.encode(video[i].unsqueeze(0).permute(0, 2, 1, 3, 4)), generator[i])
                 for i in range(batch_size)
             ]
-        else:  # `premute()` because we want `batch_size, num_channels, num_frames, height, width`
+        else:  # `premute()` because we want `batch_size, num_channels, num_frames, height, width`
             init_latents = [
                 retrieve_latents(self.vae.encode(vid.unsqueeze(0).permute(0, 2, 1, 3, 4)), generator) for vid in video
             ]

         init_latents = torch.cat(init_latents, dim=0).to(dtype)
         init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
-        # `ones()` as we want to condition on all?
-        conditioning_mask = torch.ones(mask_shape, device=device, dtype=dtype)
-
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        # TODO: consider adding the noise w.r.t the flow equation? CogVideoX vid2vid
-        # adds this noise with `add_noise()`.
-        latents = (1 - sigma) * init_latents + sigma * noise
-        # latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
-
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-        ).squeeze(-1)
+
+        latents = self.scheduler.scale_noise(sample=init_latents, timestep=timestep, noise=noise)
         latents = self._pack_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )

-        return latents, conditioning_mask
+        return latents

     @property
     def guidance_scale(self):
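For reference, on a flow-matching scheduler `scale_noise` performs the same interpolation the removed lines wrote out by hand, latents = (1 - sigma) * init_latents + sigma * noise, with sigma looked up from the supplied timestep. A minimal standalone sketch (made-up latent shape; FlowMatchEulerDiscreteScheduler assumed as the scheduler, as in the other LTX pipelines):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=50)

    init_latents = torch.randn(1, 128, 8, 16, 16)  # made-up (B, C, F, H, W) video latents
    noise = torch.randn_like(init_latents)
    timestep = scheduler.timesteps[10].expand(1)   # one timestep per batch item

    # What prepare_latents() now does through the scheduler:
    noisy = scheduler.scale_noise(sample=init_latents, timestep=timestep, noise=noise)

    # ...and the interpolation the old code spelled out inline:
    sigma = scheduler.sigmas[10]
    manual = (1 - sigma) * init_latents + sigma * noise
    assert torch.allclose(noisy, manual, atol=1e-6)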
@@ -536,6 +521,16 @@ def attention_kwargs(self):
     def interrupt(self):
         return self._interrupt

+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
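As a quick sanity check of the new `strength` handling (numbers are illustrative, not from the commit): with num_inference_steps=50 and strength=0.8, init_timestep = min(int(50 * 0.8), 50) = 40 and t_start = 50 - 40 = 10, so only the last 40 timesteps are run and the input video is noised to that point rather than to pure noise. The same arithmetic as a standalone sketch:

    # Standalone sketch of get_timesteps(); `order` stands in for scheduler.order
    # (1 for the flow-match Euler scheduler).
    def truncate_timesteps(timesteps, num_inference_steps, strength, order=1):
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        return timesteps[t_start * order:], num_inference_steps - t_start

    timesteps = list(range(1000, 0, -20))  # 50 descending timesteps, made up
    kept, steps = truncate_timesteps(timesteps, 50, strength=0.8)
    print(len(kept), steps)                # 40 40; strength=1.0 would keep all 50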
@@ -549,6 +544,7 @@ def __call__(
         frame_rate: int = 25,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        strength: float = 0.8,
         guidance_scale: float = 3,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -585,6 +581,7 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            strength: TODO
             guidance_scale (`float`, defaults to `3 `):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -643,7 +640,7 @@ def __call__(
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

         # 1. Check inputs. Raise error if not correct
-        # TODO: check for the `video`
+        # TODO: check for the `video`, `strength`
         self.check_inputs(
             prompt=prompt,
             height=height,
@@ -726,15 +723,14 @@ def __call__(
             sigmas=sigmas,
             mu=mu,
         )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        latent_sigma = torch.tensor(
-            sigmas[:1].repeat(batch_size * num_videos_per_prompt), dtype=prompt_embeds.dtype, device=device
-        )
         self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
-        latents, conditioning_mask = self.prepare_latents(
+        latents = self.prepare_latents(
             video,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -745,10 +741,8 @@ def __call__(
             device,
             generator,
             latents,
-            sigma=latent_sigma,
+            timestep=latent_timestep,
         )
-        if self.do_classifier_free_guidance:
-            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])

         # 7. Prepare micro-conditions
         latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
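Taken together, the new ordering in __call__ can be exercised outside the pipeline: build the full schedule, truncate it by strength, then take the first remaining timestep (repeated per sample) as the level passed to prepare_latents. A small sketch under the same assumptions as above (FlowMatchEulerDiscreteScheduler, made-up batch size):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    num_inference_steps, strength, batch_size = 50, 0.8, 2  # made-up values

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=num_inference_steps)

    # Truncate the schedule exactly as get_timesteps() does.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start * scheduler.order :]

    # First remaining timestep, repeated per sample: this is what __call__ now
    # forwards to prepare_latents(timestep=...) for scale_noise().
    latent_timestep = timesteps[:1].repeat(batch_size)
    print(len(timesteps), latent_timestep.shape)  # 40, torch.Size([2])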
@@ -769,8 +763,6 @@ def __call__(

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
-                # timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
-                timestep = timestep.unsqueeze(-1)
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -788,7 +780,6 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                    timestep, _ = timestep.chunk(2)

                 # compute the previous noisy sample x_t -> x_t-1
                 noise_pred = self._unpack_latents(
@@ -807,12 +798,7 @@ def __call__(
                     self.transformer_spatial_patch_size,
                     self.transformer_temporal_patch_size,
                 )
-
-                noise_pred = noise_pred[:, :, 1:]
-                noise_latents = latents[:, :, 1:]
-                pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
-
-                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                 latents = self._pack_latents(
                     latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                 )
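The last hunk drops the special-casing of the first latent frame: `scheduler.step` now runs on the full unpacked latent tensor instead of latents[:, :, 1:] with the untouched first frame concatenated back. In isolation, one such step looks roughly like this (dummy tensors, flow-match Euler scheduler assumed):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=50)

    latents = torch.randn(1, 128, 8, 16, 16)  # made-up unpacked (B, C, F, H, W) latents
    noise_pred = torch.randn_like(latents)    # stand-in for the transformer output

    t = scheduler.timesteps[0]
    # Previously step() ran on latents[:, :, 1:] and the first frame was stitched
    # back on; now the whole tensor is stepped at once.
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]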
