@@ -464,12 +464,7 @@ def check_inputs(
 
     @staticmethod
     # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, patch_size: int = 1, patch_size_t: int = 1, frame_index: int = 0, device: torch.device = None, return_unscaled_coords: bool = False) -> torch.Tensor:
-        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
-        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
-        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
-        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
-        batch_size, num_channels, num_frames, height, width = latents.shape
+    def _prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
 
         latent_sample_coords = torch.meshgrid(
             torch.arange(0, num_frames, patch_size_t, device=device),
@@ -481,17 +476,21 @@ def _prepare_video_ids(latents: torch.Tensor, scale_factor: int = 32, scale_fact
         latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
         latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
 
+        return latent_coords
+
+
+    @staticmethod
+    # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _scale_video_ids(video_ids: torch.Tensor, scale_factor: int = 32, scale_factor_t: int = 8, frame_index: int = 0, device: torch.device = None) -> torch.Tensor:
+
         scaled_latent_coords = (
-            latent_coords *
-            torch.tensor([scale_factor_t, scale_factor, scale_factor], device=latent_coords.device)[None, :, None]
+            video_ids *
+            torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
         )
         scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
         scaled_latent_coords[:, 0] += frame_index
 
-        if return_unscaled_coords:
-            return latent_coords, scaled_latent_coords
-        else:
-            return scaled_latent_coords
+        return scaled_latent_coords
 
     @staticmethod
     # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
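
The refactor above splits the old helper in two: `_prepare_video_ids` now only builds one (frame, row, column) index per transformer token, and `_scale_video_ids` maps those latent-grid indices to input-video coordinates. A minimal standalone sketch of how the pair composes (the `torch.stack` over the meshgrid output is assumed from context lines elided by the hunk):

    import torch

    def prepare_video_ids(batch_size, num_frames, height, width, patch_size=1, patch_size_t=1, device=None):
        # One (t, y, x) coordinate per token, stepping by the patch sizes.
        coords = torch.meshgrid(
            torch.arange(0, num_frames, patch_size_t, device=device),
            torch.arange(0, height, patch_size, device=device),
            torch.arange(0, width, patch_size, device=device),
            indexing="ij",
        )
        coords = torch.stack(coords, dim=0)                        # [3, F', H', W']
        coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        return coords.reshape(batch_size, 3, -1)                   # [B, 3, num_tokens]

    def scale_video_ids(video_ids, scale_factor=32, scale_factor_t=8, frame_index=0):
        # Map latent-grid indices to coordinates in the input video.
        scaled = video_ids * torch.tensor(
            [scale_factor_t, scale_factor, scale_factor], device=video_ids.device
        )[None, :, None]
        # The first latent frame encodes a single input frame, so shift the
        # temporal coordinate back and clamp rather than multiplying it out.
        scaled[:, 0] = (scaled[:, 0] + 1 - scale_factor_t).clamp(min=0)
        scaled[:, 0] += frame_index
        return scaled

    ids = prepare_video_ids(batch_size=1, num_frames=2, height=2, width=2)
    pixel_ids = scale_video_ids(ids)  # ids stays available for frame-level lookups
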
@@ -622,7 +621,7 @@ def prepare_latents(
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        # latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
+        latents = torch.load("/raid/yiyi/LTX-Video/init_latents.pt").to(device, dtype=dtype)
 
         condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
 
@@ -632,8 +631,8 @@ def prepare_latents(
         extra_conditioning_num_latents = 0
         for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            # condition_latents = torch.load("/raid/yiyi/LTX-Video/latents_before_normalize.pt").to(device, dtype=dtype)
             condition_latents = self._normalize_latents(condition_latents, self.vae.latents_mean, self.vae.latents_std)
+            condition_latents = torch.load("/raid/yiyi/LTX-Video/conditioning_latents.pt").to(device, dtype=dtype)
 
             num_data_frames = data.size(2)
             num_cond_frames = condition_latents.size(2)
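
For reference, `_normalize_latents` standardizes the VAE latents channel-wise with the model's pretrained statistics. A rough sketch of the idea; the exact signature (including the `scaling_factor` default) is assumed from the diffusers LTX pipeline rather than shown in this diff:

    import torch

    def normalize_latents(latents, latents_mean, latents_std, scaling_factor=1.0):
        # latents: [B, C, F, H, W]; mean/std are per-channel statistics.
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        return (latents - latents_mean) * scaling_factor / latents_std
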
@@ -662,10 +661,11 @@ def prepare_latents(
 
 
                 noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-                # noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype)
+                noise = torch.load("/raid/yiyi/LTX-Video/noise.pt").to(device, dtype=dtype)
                 condition_latents = torch.lerp(noise, condition_latents, strength)
 
-                condition_video_ids = self._prepare_video_ids(condition_latents, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, frame_index=frame_index, device=device)
+                condition_video_ids = self._prepare_video_ids(batch_size, condition_latents.size(2), latent_height, latent_width, patch_size=self.transformer_spatial_patch_size, patch_size_t=self.transformer_temporal_patch_size, device=device)
+                condition_video_ids = self._scale_video_ids(condition_video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=frame_index, device=device)
                 condition_latents = self._pack_latents(condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device)
                 condition_conditioning_mask = torch.full(condition_latents.shape[:2], strength, device=device, dtype=dtype)
 
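
`torch.lerp(noise, condition_latents, strength)` blends the clean conditioning latents with fresh noise, so `strength` doubles as a per-condition noise level; in terms of the diff's variables it is equivalent to:

    # lerp(a, b, w) = a + w * (b - a)
    # strength = 1.0 -> the conditioning latents are kept exactly
    # strength = 0.0 -> the conditioning signal is replaced by noise
    blended = noise + strength * (condition_latents - noise)
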
@@ -675,7 +675,8 @@ def prepare_latents(
                 extra_conditioning_mask.append(condition_conditioning_mask)
                 extra_conditioning_num_latents += condition_latents.size(1)
 
-        video_ids, video_ids_scaled = self._prepare_video_ids(latents, scale_factor_t=self.vae_temporal_compression_ratio, scale_factor=self.vae_spatial_compression_ratio, patch_size_t=self.transformer_temporal_patch_size, patch_size=self.transformer_spatial_patch_size, device=device, return_unscaled_coords=True)
+        video_ids = self._prepare_video_ids(batch_size, num_latent_frames, latent_height, latent_width, patch_size_t=self.transformer_temporal_patch_size, patch_size=self.transformer_spatial_patch_size, device=device)
+        video_ids_scaled = self._scale_video_ids(video_ids, scale_factor=self.vae_spatial_compression_ratio, scale_factor_t=self.vae_temporal_compression_ratio, frame_index=0, device=device)
         latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device)
         conditioning_mask = condition_latent_frames_mask.gather(
             1, video_ids[:, 0]
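
This gather is why the unscaled ids must remain available (the job previously done by `return_unscaled_coords`): `video_ids[:, 0]` holds each token's raw latent-frame index, which selects that frame's entry from `condition_latent_frames_mask`. A toy example of the lookup:

    import torch

    # condition_latent_frames_mask: [B, F], 1.0 where a latent frame is conditioned
    # video_ids[:, 0]:              [B, N], latent-frame index of every token
    mask = torch.tensor([[1.0, 0.0, 0.0]])          # only frame 0 is conditioned
    frame_ids = torch.tensor([[0, 0, 1, 1, 2, 2]])  # two tokens per frame
    per_token = mask.gather(1, frame_ids)           # [[1., 1., 0., 0., 0., 0.]]
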
@@ -916,6 +917,10 @@ def __call__(
             device=device,
             dtype=prompt_embeds.dtype,
         )
+
+        video_coords = video_coords.float()
+        video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
+
         init_latents = latents.clone()
 
         if self.do_classifier_free_guidance:
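
The added lines convert the temporal channel of `video_coords` from frame indices to seconds while the spatial channels stay in pixels; a toy check of the arithmetic (values are illustrative):

    import torch

    coords = torch.tensor([[[0.0, 25.0, 50.0]]])  # temporal channel, frame units
    coords[:, 0] = coords[:, 0] * (1.0 / 25.0)    # -> [[[0.0, 1.0, 2.0]]] seconds
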
@@ -949,7 +954,7 @@ def __call__(
                     latents = self.add_noise_to_image_conditioning_latents(
                         t / 1000.0,
                         init_latents,
-                        latents.float(),
+                        latents,
                         image_cond_noise_scale,
                         conditioning_mask,
                         generator,
@@ -961,7 +966,7 @@ def __call__(
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                 timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
                 noise_pred = self.transformer(
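
The `torch.min` caps each token's timestep by its conditioning strength: a token conditioned with strength `s` is never treated as noisier than `(1 - s) * 1000`, so fully conditioned tokens always see `t = 0`. Numerically:

    import torch

    timestep = torch.full((1, 1), 800.0)    # scheduler timestep, broadcast per token
    mask = torch.tensor([[1.0, 0.5, 0.0]])  # per-token conditioning strength
    per_token_t = torch.min(timestep, (1 - mask) * 1000.0)
    # -> tensor([[0., 500., 800.]])
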
@@ -973,12 +978,13 @@ def __call__(
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
+
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                     timestep, _ = timestep.chunk(2)
 
-                denoised_latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+                denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0]
                 t_eps = 1e-6
                 tokens_to_denoise_mask = (t / 1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
                 latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
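
This mask keeps strongly conditioned tokens frozen until late in the schedule: a token with strength `s` is only overwritten by `denoised_latents` once `t / 1000` drops below `1 - s`. For example:

    import torch

    t = torch.tensor(800.0)                 # current scheduler timestep
    mask = torch.tensor([[1.0, 0.5, 0.0]])  # per-token conditioning strength
    denoise = t / 1000 - 1e-6 < (1.0 - mask)
    # -> tensor([[False, False, True]]): only the unconditioned token updates now;
    # the strength-0.5 token starts updating once t drops below 500.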