Commit 16c1467

up

committed
1 parent 267583a commit 16c1467

File tree

1 file changed: +5 -6 lines changed


src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 5 additions & 6 deletions
@@ -657,7 +657,7 @@ def prepare_latents(
 
                 rope_interpolation_scale = (
                     rope_interpolation_scale *
-                    torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
+                    torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
                 )
                 rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
                 rope_interpolation_scale[:, 0] += condition.frame_index
@@ -675,17 +675,16 @@ def prepare_latents(
         latents, rope_interpolation_scale = self._pack_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, rope_interpolation_scale[:, 0]
+        )
 
         rope_interpolation_scale = (
             rope_interpolation_scale
-            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
         )
         rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
 
-        conditioning_mask = condition_latent_frames_mask.gather(
-            1, latent_coords[:, 0]
-        )
-
         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             rope_interpolation_scale = torch.cat(
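
For readers skimming the diff: the new gather at lines 678-680 now runs before rope_interpolation_scale is multiplied by the VAE compression ratios, so at that point row 0 of the tensor returned by _pack_latents presumably still holds per-token latent frame indices and can index condition_latent_frames_mask directly (the removed code used latent_coords[:, 0] for the same purpose; after scaling, those values would no longer be valid indices). Below is a minimal, hypothetical sketch of that gather pattern; the sizes and the name frame_index_per_token are made up for illustration and do not come from the pipeline.

import torch

# Made-up sizes, purely for illustration.
batch, num_latent_frames, num_tokens = 2, 8, 12

# Per-frame flag marking which latent frames were injected as conditioning.
condition_latent_frames_mask = torch.zeros(batch, num_latent_frames)
condition_latent_frames_mask[:, 0] = 1.0  # e.g. only the first latent frame is a condition

# Stand-in for rope_interpolation_scale[:, 0] before the compression-ratio scaling:
# the integer latent frame index of every packed token.
frame_index_per_token = torch.randint(0, num_latent_frames, (batch, num_tokens))

# Gather the per-frame mask at each token's frame index to get a token-level mask,
# mirroring condition_latent_frames_mask.gather(1, rope_interpolation_scale[:, 0]).
conditioning_mask = condition_latent_frames_mask.gather(1, frame_index_per_token)
print(conditioning_mask.shape)  # torch.Size([2, 12])

The remaining +/- pairs in both hunks only change which tensor's device the compression-ratio constants are created on: rope_interpolation_scale.device instead of latent_coords.device.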
