Skip to content

Commit 267583a

Browse files
committed
up
1 parent e202e46 commit 267583a

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,10 @@ def _prepare_non_first_frame_conditioning(
513513
frame_index: int,
514514
strength: float,
515515
num_prefix_latent_frames: int = 2,
516-
prefix_latents_mode: str = "soft",
516+
prefix_latents_mode: str = "concat",
517517
prefix_soft_conditioning_strength: float = 0.15,
518518
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
519-
num_latent_frames = latents.size(2)
519+
num_latent_frames = condition_latents.size(2)
520520

521521
if num_latent_frames < num_prefix_latent_frames:
522522
raise ValueError(
@@ -602,7 +602,7 @@ def prepare_latents(
602602
extra_conditioning_num_latents = (
603603
0 # Number of extra conditioning latents added (should be removed before decoding)
604604
)
605-
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype)
605+
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
606606

607607
for condition in conditions:
608608
if condition.image is not None:
@@ -632,7 +632,7 @@ def prepare_latents(
632632
latents[:, :, :num_cond_frames], condition_latents, condition.strength
633633
)
634634
condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
635-
# YiYi TODO: code path not tested
635+
636636
else:
637637
if num_data_frames > 1:
638638
(
@@ -651,47 +651,41 @@ def prepare_latents(
651651
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
652652
condition_latents = torch.lerp(noise, condition_latents, condition.strength)
653653
c_nlf = condition_latents.shape[2]
654-
condition_latents, condition_latent_coords = self._pack_latents(
654+
condition_latents, rope_interpolation_scale = self._pack_latents(
655655
condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
656656
)
657+
658+
rope_interpolation_scale = (
659+
rope_interpolation_scale *
660+
torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
661+
)
662+
rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
663+
rope_interpolation_scale[:, 0] += condition.frame_index
664+
657665
conditioning_mask = torch.full(
658666
condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
659667
)
660668

661-
rope_interpolation_scale = [
662-
# TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation
663-
# scale with the grid.
664-
(self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate,
665-
self.vae_spatial_compression_ratio,
666-
self.vae_spatial_compression_ratio,
667-
]
668-
rope_interpolation_scale = (
669-
torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
670-
.view(-1, 1, 1, 1, 1)
671-
.repeat(1, 1, c_nlf, latent_height, latent_width)
672-
)
673669
extra_conditioning_num_latents += condition_latents.size(1)
674670

675671
extra_conditioning_latents.append(condition_latents)
676672
extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
677673
extra_conditioning_mask.append(conditioning_mask)
678674

679-
latents, latent_coords = self._pack_latents(
675+
latents, rope_interpolation_scale = self._pack_latents(
680676
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
681677
)
682-
pixel_coords = (
683-
latent_coords
678+
679+
rope_interpolation_scale = (
680+
rope_interpolation_scale
684681
* torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
685682
)
686-
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
687-
688-
rope_interpolation_scale = pixel_coords
683+
rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
689684

690685
conditioning_mask = condition_latent_frames_mask.gather(
691686
1, latent_coords[:, 0]
692687
)
693688

694-
# YiYi TODO: code path not tested
695689
if len(extra_conditioning_latents) > 0:
696690
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
697691
rope_interpolation_scale = torch.cat(

0 commit comments

Comments (0)