@@ -544,6 +544,28 @@ def _prepare_non_first_frame_conditioning(
 
         return latents, condition_latents, condition_latent_frames_mask
 
+
+    def trim_conditioning_sequence(
+        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
+    ):
+        """
+        Trim a conditioning sequence to the allowed number of frames.
+
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+
+        Returns:
+            int: The updated sequence length.
+        """
+        scale_factor = self.vae_temporal_compression_ratio
+        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+        # Trim down to a multiple of `scale_factor` frames plus 1
+        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+        return num_frames
+
+
     def prepare_latents(
         self,
         conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
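The trimming rule added above snaps the usable conditioning length down to the `k * scale_factor + 1` form that the frame-count check further down in `prepare_latents` expects. A minimal standalone sketch of the arithmetic, assuming a VAE temporal compression ratio of 8 (in the pipeline the value comes from `self.vae_temporal_compression_ratio`):

def trim(start_frame, sequence_num_frames, target_num_frames, scale_factor=8):
    # Cap at the frames remaining before the target end, then round down to k * scale_factor + 1
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    return (num_frames - 1) // scale_factor * scale_factor + 1

trim(0, 30, 121)    # -> 25  (3 * 8 + 1)
trim(100, 30, 121)  # -> 17  (only 21 frames fit before the target end, rounded down to 2 * 8 + 1)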
@@ -579,7 +601,19 @@ def prepare_latents(
             if condition.image is not None:
                 data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
             elif condition.video is not None:
-                data = self.video_processor.preprocess_video(condition.vide, height, width)
+                data = self.video_processor.preprocess_video(condition.video, height, width)
+                num_frames_input = data.size(2)
+                num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
+                data = data[:, :, :num_frames_output]
+
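+                # Debug scaffold (presumably temporary): compare the trimmed, preprocessed video against a tensor dump from a reference LTX-Video run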
+                print(data.shape)
+                print(data[0, 0, :3, :5, :5])
+                data_loaded = torch.load("/raid/yiyi/LTX-Video/media_item.pt")
+                print(data_loaded.shape)
+                print(data_loaded[0, 0, :3, :5, :5])
+                print(torch.sum((data_loaded - data).abs()))
+                print(f"dtype: {dtype}, device: {device}")
+                data = data.to(device, dtype=torch.bfloat16)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
 
@@ -589,8 +623,19 @@ def prepare_latents(
                     f"but got {data.size(2)} frames."
                 )
 
+            print(f"before encode: {data.shape}, {data.dtype}, {data.device}")
+
             condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
             condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)
+
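+            # Debug scaffold (presumably temporary): compare the normalized condition latents against a reference dump, then halt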
+            print(f"after normalize: {condition_latents.shape}")
+            print(condition_latents[0, 0, :3, :5, :5])
+            condition_latents_loaded = torch.load("/raid/yiyi/LTX-Video/latents_normalized.pt")
+            print(condition_latents_loaded.shape)
+            print(condition_latents_loaded[0, 0, :3, :5, :5])
+            print(torch.sum((condition_latents_loaded - condition_latents).abs()))
+            assert False
+
             num_data_frames = data.size(2)
             num_cond_frames = condition_latents.size(2)
 
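For context, the `_normalize_latents` call above standardizes each latent channel with the per-channel statistics stored on the VAE (`latents_mean`, `latents_std`). A rough sketch of that behavior for latents shaped [B, C, F, H, W]; the actual helper lives elsewhere in the pipeline and may differ in details such as an extra scaling factor:

import torch

def normalize_latents(latents, latents_mean, latents_std):
    # Broadcast the per-channel stats over batch, frame, and spatial dims
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - latents_mean) / latents_std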