@@ -21,6 +21,7 @@
 from transformers import T5EncoderModel, T5TokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKLLTXVideo
 from ...models.transformers import LTXVideoTransformer3DModel
@@ -45,12 +46,11 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import LTXImageToVideoPipeline
+        >>> from diffusers import LTXConditionPipeline
         >>> from diffusers.utils import export_to_video, load_image

-        >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+        >>> pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
-
         >>> image = load_image(
         ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
         ... )
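The excerpt cuts off before the generation call. For context, a minimal sketch of how the updated example plausibly continues, driving the pipeline through an `LTXVideoCondition` (the import path, prompt text, and generation settings here are illustrative assumptions, not part of this patch):

```py
>>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition  # assumed export location

>>> condition = LTXVideoCondition(image=image, frame_index=0)  # pin the loaded image to the first frame
>>> video = pipe(
...     conditions=condition,
...     prompt="A young girl looks around in surprise",  # illustrative prompt
...     width=704,
...     height=480,
...     num_frames=161,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```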
@@ -405,6 +405,11 @@ def encode_prompt(
     def check_inputs(
         self,
         prompt,
+        conditions,
+        image,
+        video,
+        frame_index,
+        strength,
         height,
         width,
         callback_on_step_end_tensor_inputs=None,
@@ -455,6 +460,26 @@ def check_inputs(
                     f" {negative_prompt_attention_mask.shape}."
                 )

+        if conditions is not None and (image is not None or video is not None):
+            raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
+
+        if conditions is None and (image is None and video is None):
+            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
+
+        if conditions is None:
+            if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
+                raise ValueError(
+                    "If `conditions` is not provided, `image` and `frame_index` must be of the same length."
+                )
+            elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength):
+                raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.")
+            elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index):
+                raise ValueError(
+                    "If `conditions` is not provided, `video` and `frame_index` must be of the same length."
+                )
+            elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
+                raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
+
     @staticmethod
     def _prepare_video_ids(
         batch_size: int,
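The new checks make `conditions` mutually exclusive with the raw `image`/`video` inputs and require list-valued arguments to pair up. A standalone sketch of the contract being enforced, to make the accepted combinations concrete (function name and values are illustrative, not pipeline code):

```py
def validate_conditioning(conditions=None, image=None, video=None, frame_index=0, strength=1.0):
    # Either condition objects or raw inputs, never both.
    if conditions is not None and (image is not None or video is not None):
        raise ValueError("Pass either `conditions` or raw `image`/`video`, not both.")
    # At least one conditioning source is required.
    if conditions is None and image is None and video is None:
        raise ValueError("One of `conditions`, `image`, or `video` is required.")
    # List-valued per-item settings must line up one-to-one with the inputs.
    if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
        raise ValueError("`image` and `frame_index` must have the same length.")


validate_conditioning(image=["img.png"], frame_index=[0])       # ok
# validate_conditioning(conditions=object(), image="img.png")   # would raise ValueError
```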
@@ -699,7 +724,8 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        video_ids_scaled = self._scale_video_ids(
+        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
             scale_factor_t=self.vae_temporal_compression_ratio,
@@ -709,11 +735,10 @@ def prepare_latents(
         latents = self._pack_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])

         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
-            video_ids = torch.cat([*extra_conditioning_video_ids, video_ids_scaled], dim=2)
+            video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)

         return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
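The reordering matters because `condition_latent_frames_mask` is indexed by latent-frame number: the gather has to run on the raw `video_ids` before `_scale_video_ids` rescales them by the VAE compression ratios, which also removes the need for the `video_ids_scaled` temporary. A small illustration of the gather itself (tensor values are made up):

```py
import torch

# Per-latent-frame conditioning strength, shape (batch, num_latent_frames).
condition_latent_frames_mask = torch.tensor([[1.0, 0.0, 0.0, 0.5]])
# First row of video_ids: the (unscaled) latent-frame index of every token.
temporal_ids = torch.tensor([[0, 0, 1, 2, 3, 3]])

conditioning_mask = condition_latent_frames_mask.gather(1, temporal_ids)
print(conditioning_mask)  # tensor([[1.0000, 1.0000, 0.0000, 0.0000, 0.5000, 0.5000]])
```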
@@ -742,7 +767,11 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
+        conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None,
+        image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+        video: List[PipelineImageInput] = None,
+        frame_index: Union[int, List[int]] = 0,
+        strength: Union[float, List[float]] = 1.0,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
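Since every conditioning argument now has a default, callers can skip building `LTXVideoCondition` objects entirely. A sketch of the two now-equivalent call styles, reusing `pipe` and `image` from the docstring example (`LTXVideoCondition` import assumed; prompt illustrative):

```py
>>> prompt = "A young girl looks around in surprise"  # illustrative

>>> # Style 1: explicit condition objects (previous behaviour, still supported)
>>> video = pipe(conditions=LTXVideoCondition(image=image, frame_index=0), prompt=prompt).frames[0]

>>> # Style 2: raw inputs, enabled by the new defaults
>>> video = pipe(image=image, frame_index=0, strength=1.0, prompt=prompt).frames[0]
```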
@@ -773,8 +802,19 @@ def __call__(
         Function invoked when calling the pipeline for generation.

         Args:
-            conditions (`List[LTXVideoCondition]`):
-                The list of frame-conditioning items for the video generation.
+            conditions (`List[LTXVideoCondition]`, *optional*):
+                The list of frame-conditioning items for the video generation. If not provided, conditions will be
+                created using `image`, `video`, `frame_index` and `strength`.
+            image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The image or images to condition the video generation. If not provided, one has to pass `video` or
+                `conditions`.
+            video (`List[PipelineImageInput]`, *optional*):
+                The video to condition the video generation. If not provided, one has to pass `image` or `conditions`.
+            frame_index (`int` or `List[int]`, *optional*):
+                The frame index or frame indices at which the image or video will conditionally affect the video
+                generation. If not provided, one has to pass `conditions`.
+            strength (`float` or `List[float]`, *optional*):
+                The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
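The list forms allow several conditioning items in one call, with `frame_index` and `strength` either broadcast from scalars or given per item. A sketch, assuming two PIL images `first_frame` and `last_frame` and a 161-frame output (all values illustrative):

```py
>>> video = pipe(
...     image=[first_frame, last_frame],  # two conditioning images
...     frame_index=[0, 160],             # pin them to the first and last output frame
...     strength=[1.0, 0.8],              # let the second image condition more weakly
...     prompt=prompt,
...     num_frames=161,
... ).frames[0]
```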
@@ -857,6 +897,11 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
+            conditions=conditions,
+            image=image,
+            video=video,
+            frame_index=frame_index,
+            strength=strength,
             height=height,
             width=width,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -878,6 +923,31 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

+        if conditions is not None:
+            if not isinstance(conditions, list):
+                conditions = [conditions]
+
+            strength = [condition.strength for condition in conditions]
+            frame_index = [condition.frame_index for condition in conditions]
+            image = [condition.image for condition in conditions]
+            video = [condition.video for condition in conditions]
+        else:
+            if not isinstance(image, list):
+                image = [image]
+                num_conditions = 1
+            elif isinstance(image, list):
+                num_conditions = len(image)
+            if not isinstance(video, list):
+                video = [video]
+                num_conditions = 1
+            elif isinstance(video, list):
+                num_conditions = len(video)
+
+            if not isinstance(frame_index, list):
+                frame_index = [frame_index] * num_conditions
+            if not isinstance(strength, list):
+                strength = [strength] * num_conditions
+
         device = self._execution_device

         # 3. Prepare text embeddings
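After this block the two entry points converge on the same representation: four parallel lists with one slot per condition, so the rest of `__call__` no longer needs to know which style the caller used. A standalone sketch of that invariant (the `Condition` dataclass is an illustrative stand-in for `LTXVideoCondition`; values made up):

```py
from dataclasses import dataclass
from typing import Optional


@dataclass
class Condition:  # illustrative stand-in for LTXVideoCondition
    image: Optional[str] = None
    video: Optional[str] = None
    frame_index: int = 0
    strength: float = 1.0


conditions = [Condition(image="a.png"), Condition(video="b.mp4", frame_index=8)]
# After the `conditions is not None` branch, the four lists are parallel:
image = [c.image for c in conditions]              # ['a.png', None]
video = [c.video for c in conditions]              # [None, 'b.mp4']
frame_index = [c.frame_index for c in conditions]  # [0, 8]
strength = [c.strength for c in conditions]        # [1.0, 1.0]
```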
@@ -905,17 +975,20 @@ def __call__(
         vae_dtype = self.vae.dtype

         conditioning_tensors = []
-        conditioning_strengths = []
-        conditioning_start_frames = []
-
-        for condition in conditions:
-            if condition.image is not None:
-                condition_tensor = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
-            elif condition.video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition.video, height, width)
+        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+            image, video, frame_index, strength
+        ):
+            if condition_image is not None:
+                condition_tensor = (
+                    self.video_processor.preprocess(condition_image, height, width)
+                    .unsqueeze(2)
+                    .to(device, dtype=vae_dtype)
+                )
+            elif condition_video is not None:
+                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
                 num_frames_input = condition_tensor.size(2)
                 num_frames_output = self.trim_conditioning_sequence(
-                    condition.frame_index, num_frames_input, num_frames
+                    condition_frame_index, num_frames_input, num_frames
                 )
                 condition_tensor = condition_tensor[:, :, :num_frames_output]
                 condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
@@ -928,15 +1001,13 @@ def __call__(
                    f"but got {condition_tensor.size(2)} frames."
                 )
             conditioning_tensors.append(condition_tensor)
-            conditioning_strengths.append(condition.strength)
-            conditioning_start_frames.append(condition.frame_index)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
             conditioning_tensors,
-            conditioning_strengths,
-            conditioning_start_frames,
+            strength,
+            frame_index,
             batch_size=batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents,
             height=height,
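Video conditions are trimmed so their frame count fits the remaining output and satisfies the VAE's temporal constraint (`k * ratio + 1` frames). A sketch of what `trim_conditioning_sequence` plausibly computes, assuming it clamps to the frames left after `frame_index` and then rounds down to a valid length (the ratio of 8 matches the LTX VAE's temporal compression; the real method may differ):

```py
def trim_conditioning_sequence(start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int:
    scale_factor = 8  # assumed vae_temporal_compression_ratio for LTX
    # Keep only the frames that still fit into the target video...
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    # ...then round down to the nearest k * scale_factor + 1.
    return (num_frames - 1) // scale_factor * scale_factor + 1


print(trim_conditioning_sequence(0, 50, 161))    # 49 (6 * 8 + 1)
print(trim_conditioning_sequence(120, 50, 161))  # 41 (clamped to the 41 frames left)
```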
@@ -1015,9 +1086,10 @@ def __call__(
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                     timestep, _ = timestep.chunk(2)

-                denoised_latents = self.scheduler.step(-noise_pred, timestep, latents, return_dict=False)[0]
-                t_eps = 1e-6
-                tokens_to_denoise_mask = (t / 1000 - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
+                denoised_latents = self.scheduler.step(
+                    -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
+                )[0]
+                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
                 latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)

                 if callback_on_step_end is not None:
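The rewritten step now passes the scalar `t` as the scheduler timestep and the (possibly CFG-halved) `timestep` tensor through `per_token_timesteps`, with the `t_eps` constant inlined. Tokens whose conditioning strength exceeds the current noise level are left untouched. A standalone sketch of that freezing mask (values illustrative):

```py
import torch

t = torch.tensor(400.0)                              # current timestep on a 0..1000 scale
conditioning_mask = torch.tensor([[1.0, 0.5, 0.0]])  # per-token conditioning strength
denoised_latents = torch.zeros(1, 3, 4)
latents = torch.ones(1, 3, 4)

# A token is denoised only once the schedule drops below 1 - strength, so
# fully conditioned tokens (strength 1.0) are never overwritten.
mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(mask, denoised_latents, latents)
print(mask.squeeze(-1))  # tensor([[False,  True,  True]])
```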