@@ -437,8 +437,8 @@ def check_inputs(
             )
 
     @staticmethod
-    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+    # Adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> Tuple[torch.Tensor, torch.Tensor]:
         # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
@@ -447,6 +447,16 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
@@ -458,7 +468,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords
 
     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
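
Note: `_pack_latents` now also builds and returns a per-token coordinate grid. A minimal standalone sketch of that logic (toy sizes chosen for illustration, patch sizes of 1 as in the function's defaults, with `indexing="ij"` spelled out explicitly):

    import torch

    # Toy reproduction of the coordinate grid built above (sizes are illustrative).
    batch_size, num_frames, height, width = 1, 2, 2, 3
    coords = torch.meshgrid(
        torch.arange(num_frames), torch.arange(height), torch.arange(width), indexing="ij"
    )
    coords = torch.stack(coords, dim=0)                           # [3, F, H, W]
    coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)   # [B, 3, F, H, W]
    coords = coords.reshape(batch_size, -1, num_frames * height * width)  # [B, 3, F*H*W]
    print(coords[0, :, :4])
    # tensor([[0, 0, 0, 0],    <- latent frame index of each token
    #         [0, 0, 0, 1],    <- latent row index
    #         [0, 1, 2, 0]])   <- latent column index

Each column is the (frame, row, column) position of one packed token; `prepare_latents` later scales these into pixel coordinates for RoPE.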
@@ -544,6 +554,25 @@ def _prepare_non_first_frame_conditioning(
 
         return latents, condition_latents, condition_latent_frames_mask
 
+    def trim_conditioning_sequence(
+        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
+    ):
+        """
+        Trim a conditioning sequence to the allowed number of frames.
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+        Returns:
+            int: updated sequence length
+        """
+        scale_factor = self.vae_temporal_compression_ratio
+        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+        # Trim down to a multiple of temporal_scale_factor frames plus 1
+        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+        return num_frames
+
+
     def prepare_latents(
         self,
         conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
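
Note: the trimming arithmetic first caps the clip at the frames remaining in the target video, then shortens it to a multiple of the VAE's temporal compression ratio plus one (the first latent frame encodes a single pixel frame). A worked example with assumed values (ratio 8, as in the LTX VAE):

    # Assumed: scale_factor = 8; a 40-frame conditioning clip anchored at
    # frame 0 of a 30-frame generation.
    scale_factor = 8
    num_frames = min(40, 30 - 0)                                       # -> 30
    num_frames = (num_frames - 1) // scale_factor * scale_factor + 1   # -> 25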
@@ -579,7 +608,11 @@ def prepare_latents(
             if condition.image is not None:
                 data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
             elif condition.video is not None:
-                data = self.video_processor.preprocess_video(condition.vide, height, width)
+                data = self.video_processor.preprocess_video(condition.video, height, width)
+                num_frames_input = data.size(2)
+                num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
+                data = data[:, :, :num_frames_output]
+                data = data.to(device, dtype=dtype)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
@@ -599,6 +632,7 @@ def prepare_latents(
                     latents[:, :, :num_cond_frames], condition_latents, condition.strength
                 )
                 condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
+            # YiYi TODO: code path not tested
             else:
                 if num_data_frames > 1:
                     (
@@ -617,8 +651,8 @@ def prepare_latents(
                     noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                     condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                     c_nlf = condition_latents.shape[2]
-                    condition_latents = self._pack_latents(
-                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                    condition_latents, condition_latent_coords = self._pack_latents(
+                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                     )
                     conditioning_mask = torch.full(
                         condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
@@ -642,23 +676,22 @@ def prepare_latents(
                     extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
                     extra_conditioning_mask.append(conditioning_mask)
 
-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, latent_coords = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
+        pixel_coords = (
+            latent_coords
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=latent_coords.device)[None, :, None]
         )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+
+        rope_interpolation_scale = pixel_coords
+
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, latent_coords[:, 0]
         )
 
+        # YiYi TODO: code path not tested
         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             rope_interpolation_scale = torch.cat(
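
Note: `rope_interpolation_scale` changes from a single per-axis scale tensor to a per-token grid of pixel-space coordinates. A small sketch of the mapping under assumed LTX compression ratios (8 temporal, 32 spatial): a token at latent position (f, y, x) lands at pixel position (max(8f + 1 - 8, 0), 32y, 32x), the temporal shift reflecting that the first latent frame covers only one pixel frame:

    import torch

    # Assumed ratios: 8 (temporal), 32 (spatial). Three example tokens at
    # latent positions (0, 0, 0), (1, 0, 2), (2, 1, 0).
    latent_coords = torch.tensor([[[0, 1, 2], [0, 0, 1], [0, 2, 0]]])  # [B=1, 3, tokens]
    pixel_coords = latent_coords * torch.tensor([8, 32, 32])[None, :, None]
    pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - 8).clamp(min=0)
    print(pixel_coords)
    # tensor([[[ 0,  1,  9],    <- pixel frame of each token
    #          [ 0,  0, 32],    <- pixel row
    #          [ 0, 64,  0]]])  <- pixel column

The conditioning mask is likewise built per token, by gathering each token's latent frame index from `condition_latent_frames_mask`.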
@@ -864,7 +897,7 @@ def __call__(
             frame_rate,
             generator,
             device,
-            torch.float32,
+            prompt_embeds.dtype,
         )
         init_latents = latents.clone()
@@ -955,8 +988,8 @@ def __call__(
                 pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
 
                 latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
-                latents = self._pack_latents(
-                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                latents, _ = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                 )
 
                 if callback_on_step_end is not None: