
Commit cbc035d

make it work

1 parent 6f7e837 commit cbc035d

1 file changed: 26 additions, 20 deletions

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

@@ -464,12 +464,7 @@ def check_inputs(
 
     @staticmethod
     # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, patch_size: int = 1, patch_size_t: int = 1, frame_index: int = 0, device: torch.device = None, return_unscaled_coords: bool = False) -> torch.Tensor:
-        # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
-        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
-        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
-        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
-        batch_size, num_channels, num_frames, height, width = latents.shape
+    def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
 
         latent_sample_coords = torch.meshgrid(
             torch.arange(0, num_frames, patch_size_t, device=device),
@@ -481,17 +476,21 @@ def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_fact
         latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
         latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
 
+        return latent_coords
+
+
+    @staticmethod
+    # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, frame_index: int = 0, device: torch.device = None) -> torch.Tensor:
+
         scaled_latent_coords = (
-            latent_coords *
-            torch.tensor([scale_factor_t, scale_factor, scale_factor], device=latent_coords.device)[None, :, None]
+            video_ids *
+            torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
         )
         scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
         scaled_latent_coords[:, 0] += frame_index
 
-        if return_unscaled_coords:
-            return latent_coords, scaled_latent_coords
-        else:
-            return scaled_latent_coords
+        return scaled_latent_coords
 
     @staticmethod
     # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
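
The old helper both built the coordinate grid and scaled it, and had grown a `return_unscaled_coords` flag to expose both results; the refactor splits those concerns, so callers that need the unscaled ids simply keep the output of `_prepare_video_ids`. Below is a minimal standalone sketch of what the two steps now compute, on a toy grid (illustrative sizes, not the pipeline's real VAE ratios):

    import torch

    # Toy grid: batch_size=1, 3 latent frames, 2x2 latent pixels, patch sizes of 1.
    batch_size, num_frames, height, width = 1, 3, 2, 2

    # _prepare_video_ids equivalent: one (t, y, x) index triple per token.
    coords = torch.meshgrid(
        torch.arange(0, num_frames),
        torch.arange(0, height),
        torch.arange(0, width),
        indexing="ij",
    )
    video_ids = torch.stack(coords, dim=0)                            # [3, F, H, W]
    video_ids = video_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
    video_ids = video_ids.reshape(batch_size, -1, num_frames * height * width)
    print(video_ids.shape)  # torch.Size([1, 3, 12])
    print(video_ids[0, 0])  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) -- temporal row, latent frames

    # _scale_video_ids equivalent: map latent coordinates to pixel/frame space.
    scale_factor, scale_factor_t, frame_index = 32, 8, 0
    scaled = video_ids * torch.tensor([scale_factor_t, scale_factor, scale_factor])[None, :, None]
    scaled[:, 0] = (scaled[:, 0] + 1 - scale_factor_t).clamp(min=0)
    scaled[:, 0] += frame_index
    print(scaled[0, 0])     # tensor([0, 0, 0, 0, 1, 1, 1, 1, 9, 9, 9, 9]) -- first frame pinned to 0 by the clamp
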
@@ -622,7 +621,7 @@ def prepare_latents(
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
+        latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
 
         condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
 
@@ -632,8 +631,8 @@ def prepare_latents(
         extra_conditioning_num_latents = 0
         for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
             condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            # condition_latents = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt").to(device, dtype=dtype)
             condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)
+            condition_latents = torch.load("/raid/yiyi/LTX-Video/conditioning_latents.pt").to(device, dtype=dtype)
 
             num_data_frames = data.size(2)
             num_cond_frames = condition_latents.size(2)
@@ -662,10 +661,11 @@ def prepare_latents(
 
 
             noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-            # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype)
+            noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype)
             condition_latents = torch.lerp(noise, condition_latents, strength)
 
-            condition_video_ids = self._prepare_video_ids(condition_latents, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, frame_index=frame_index, device=device)
+            condition_video_ids = self._prepare_video_ids(batch_size, condition_latents.size(2), latent_height, latent_width, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, device=device)
+            condition_video_ids = self._scale_video_ids(condition_video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=frame_index, device=device)
             condition_latents = self._pack_latents(condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device)
             condition_conditioning_mask = torch.full(condition_latents.shape[:2], strength, device=device, dtype=dtype)
 
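
Two things happen in this hunk: the encoded conditioning latents are blended toward noise with `torch.lerp`, and the old one-shot `_prepare_video_ids` call becomes the prepare-then-scale pair. The lerp semantics in isolation (toy tensors, only to show what `strength` does):

    import torch

    noise = torch.zeros(2, 3)
    condition_latents = torch.ones(2, 3)

    # torch.lerp(start, end, weight) = start + weight * (end - start)
    print(torch.lerp(noise, condition_latents, 1.0))  # ones  -> condition kept exactly
    print(torch.lerp(noise, condition_latents, 0.5))  # 0.5s  -> half condition, half noise
    print(torch.lerp(noise, condition_latents, 0.0))  # zeros -> condition fully replaced by noise
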
@@ -675,7 +675,8 @@ def prepare_latents(
             extra_conditioning_mask.append(condition_conditioning_mask)
             extra_conditioning_num_latents += condition_latents.size(1)
 
-        video_ids, video_ids_scaled = self._prepare_video_ids(latents, scale_factor_t = self.vae_temporal_compression_ratio, scale_factor = self.vae_spatial_compression_ratio, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device, return_unscaled_coords=True)
+        video_ids = self._prepare_video_ids(batch_size, num_latent_frames, latent_height, latent_width, patch_size_t = self.transformer_temporal_patch_size, patch_size = self.transformer_spatial_patch_size, device=device)
+        video_ids_scaled = self._scale_video_ids(video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=0, device=device)
         latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device)
         conditioning_mask = condition_latent_frames_mask.gather(
             1, video_ids[:, 0]
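
This is why the split keeps the unscaled ids around as `video_ids`: the `gather` just below indexes the per-frame conditioning mask with the temporal row of the unscaled coordinates, producing one mask value per packed token. A small sketch of that lookup, with hypothetical shapes:

    import torch

    # Per-frame mask: latent frame 0 is a conditioning frame, frames 1-3 are not.
    condition_latent_frames_mask = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

    # Temporal row of the unscaled video_ids (here: 2 tokens per latent frame).
    temporal_ids = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3]])

    # gather expands the per-frame mask to one value per packed token.
    conditioning_mask = condition_latent_frames_mask.gather(1, temporal_ids)
    print(conditioning_mask)  # tensor([[1., 1., 0., 0., 0., 0., 0., 0.]])
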
@@ -916,6 +917,10 @@ def __call__(
             device=device,
             dtype=prompt_embeds.dtype,
         )
+
+        video_coords = video_coords.float()
+        video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
+
         init_latents = latents.clone()
 
         if self.do_classifier_free_guidance:
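
The added lines convert the temporal row of `video_coords` from frame indices to seconds, which is why the float cast on the line before is needed. A worked example with an assumed 25 fps (`frame_rate` is a pipeline argument; 25 here is illustrative):

    import torch

    frame_rate = 25.0
    # Temporal coordinates of four tokens after scaling, in frames: 0, 1, 9, 17
    # (only the temporal row of the [B, 3, tokens] tensor is shown).
    video_coords = torch.tensor([[[0, 1, 9, 17]]])

    video_coords = video_coords.float()
    video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
    print(video_coords)  # tensor([[[0.0000, 0.0400, 0.3600, 0.6800]]]) -- seconds
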
@@ -949,7 +954,7 @@ def __call__(
                     latents = self.add_noise_to_image_conditioning_latents(
                         t/1000.0,
                         init_latents,
-                        latents.float(),
+                        latents,
                         image_cond_noise_scale,
                         conditioning_mask,
                         generator,
@@ -961,7 +966,7 @@ def __call__(
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                 timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
                 noise_pred = self.transformer(
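
The new `.float()` matters because the next line takes an elementwise `torch.min` against a float mask expression: each token's timestep is capped at `(1 - conditioning_strength) * 1000`, so hard-conditioned tokens are effectively frozen at t=0 while unconditioned tokens keep the scheduler's timestep. Numerically:

    import torch

    t = torch.tensor(800.0)  # current denoising timestep on the 0-1000 scale
    # Per-token conditioning strength: 1.0 hard condition, 0.5 soft, 0.0 unconditioned.
    conditioning_mask_model_input = torch.tensor([[1.0, 0.5, 0.0]])

    timestep = t.expand(1).unsqueeze(-1).float()  # tensor([[800.]])
    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
    print(timestep)  # tensor([[  0., 500., 800.]])
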
@@ -973,12 +978,13 @@ def __call__(
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
+
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                     timestep, _ = timestep.chunk(2)
 
-                denoised_latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+                denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0]
                 t_eps = 1e-6
                 tokens_to_denoise_mask = (t/1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
                 latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
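
The negation of `noise_pred` suggests the transformer predicts the data-to-noise direction while the scheduler integrates noise-to-data. The sketch below is only an illustration of why such a sign flip is needed under one common flow-matching convention; it is not LTX's actual scheduler code:

    import torch

    # Assume a scheduler that Euler-integrates x <- x + (sigma_next - sigma) * v,
    # where the path is x_t = (1 - sigma) * x0 + sigma * eps, so v = eps - x0.
    # If the model was trained to predict x0 - eps, its output must be negated.
    x0, eps = torch.randn(2, 4), torch.randn(2, 4)
    sigma, sigma_next = 1.0, 0.9

    x_t = (1 - sigma) * x0 + sigma * eps   # current sample (pure noise at sigma=1)
    model_pred = x0 - eps                  # opposite-sign velocity prediction

    x_next = x_t + (sigma_next - sigma) * (-model_pred)  # negate before stepping
    x_ref = (1 - sigma_next) * x0 + sigma_next * eps     # exact point on the path
    assert torch.allclose(x_next, x_ref, atol=1e-6)
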
