Skip to content

Commit 81f2468

Browse files
committed
up
1 parent d0bdf4b commit 81f2468

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,28 @@ def _prepare_non_first_frame_conditioning(
544544

545545
return latents, condition_latents, condition_latent_frames_mask
546546

547+
548+
def trim_conditioning_sequence(
    self, start_frame: int, sequence_num_frames: int, target_num_frames: int
):
    """
    Trim a conditioning sequence to the allowed number of frames.

    Args:
        start_frame (int): The target frame number of the first frame in the sequence.
        sequence_num_frames (int): The number of frames in the sequence.
        target_num_frames (int): The target number of frames in the generated video.

    Returns:
        int: updated sequence length
    """
    temporal_ratio = self.vae_temporal_compression_ratio
    # The sequence may not extend past the end of the generated video.
    allowed_frames = min(sequence_num_frames, target_num_frames - start_frame)
    # Snap down to a multiple of the temporal compression ratio, plus one leading frame.
    return (allowed_frames - 1) // temporal_ratio * temporal_ratio + 1
567+
568+
547569
def prepare_latents(
548570
self,
549571
conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
@@ -579,7 +601,19 @@ def prepare_latents(
579601
if condition.image is not None:
580602
data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
581603
elif condition.video is not None:
582-
data = self.video_processor.preprocess_video(condition.vide, height, width)
604+
data = self.video_processor.preprocess_video(condition.video, height, width)
605+
num_frames_input = data.size(2)
606+
num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
607+
data = data[:, :, :num_frames_output]
608+
609+
print(data.shape)
610+
print(data[0,0,:3,:5,:5])
611+
data_loaded = torch.load("/raid/yiyi/LTX-Video/media_item.pt")
612+
print(data_loaded.shape)
613+
print(data_loaded[0,0,:3,:5,:5])
614+
print(torch.sum((data_loaded - data).abs()))
615+
print(f" dtype:{dtype}, device:{device}")
616+
data = data.to(device, dtype=torch.bfloat16)
583617
else:
584618
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
585619

@@ -589,8 +623,19 @@ def prepare_latents(
589623
f"but got {data.size(2)} frames."
590624
)
591625

626+
print(f" before encode: {data.shape}, {data.dtype}, {data.device}")
627+
592628
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
593629
condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)
630+
631+
print(f" after normalize: {condition_latents.shape}")
632+
print(condition_latents[0,0,:3,:5,:5])
633+
condition_latents_loaded = torch.load("/raid/yiyi/LTX-Video/latents_normalized.pt")
634+
print(condition_latents_loaded.shape)
635+
print(condition_latents_loaded[0,0,:3,:5,:5])
636+
print(torch.sum((condition_latents_loaded - condition_latents).abs()))
637+
assert False
638+
594639
num_data_frames = data.size(2)
595640
num_cond_frames = condition_latents.size(2)
596641

0 commit comments

Comments
 (0)