 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
 from ...models.autoencoders import AutoencoderKLLTXVideo
 from ...models.transformers import LTXVideoTransformer3DModel
@@ -559,7 +558,7 @@ def _extract_spatial_tile(self, latents, v_start, v_end, h_start, h_end):
         """Extract spatial tiles from all inputs for a given spatial region."""
         tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end]
         return tile_latents
-    
+
     def _select_latents(self, latents: torch.Tensor, start_index: int, end_index: int) -> torch.Tensor:
         num_frames = latents.shape[2]
         start_idx = num_frames + start_index if start_index < 0 else start_index
@@ -570,11 +569,9 @@ def _select_latents(self, latents: torch.Tensor, start_index: int, end_index: in
             start_idx = min(start_idx, end_idx)
         latents = latents[:, :, start_idx : end_idx + 1, :, :].clone()
         return latents
-    
+
     @staticmethod
-    def _create_spatial_weights(
-        latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
-    ):
+    def _create_spatial_weights(latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap):
         """Create blending weights for spatial tiles."""
         tile_weights = torch.ones_like(latents)
 
@@ -658,7 +655,7 @@ def prepare_latents(
         latent_width = width // self.vae_spatial_compression_ratio
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        
+
         if latents is not None:
             if latents.shape != shape:
                 raise ValueError(
@@ -678,11 +675,7 @@ def prepare_latents(
             device=device,
         )
         video_ids = self._scale_video_ids(
-            video_ids,
-            self.vae_spatial_compression_ratio,
-            self.vae_temporal_compression_ratio,
-            0,
-            device
+            video_ids, self.vae_spatial_compression_ratio, self.vae_temporal_compression_ratio, 0, device
         )
 
         return latents, video_ids
@@ -857,7 +850,9 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
         if horizontal_tiles > 1 or vertical_tiles > 1:
-            raise ValueError("Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 0 is not supported yet.")
+            raise ValueError(
+                "Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 0 is not supported yet."
+            )
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -967,11 +962,14 @@ def __call__(
                 first_tile_out_latents = None
 
                 for index_temporal_tile, (start_index, end_index) in enumerate(
-                    zip(range(0, temporal_range_max, temporal_range_step),
-                        range(temporal_tile_size, temporal_range_max, temporal_range_step)
+                    zip(
+                        range(0, temporal_range_max, temporal_range_step),
+                        range(temporal_tile_size, temporal_range_max, temporal_range_step),
                     )
                 ):
-                    latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
+                    latent_chunk = self._select_latents(
+                        tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1)
+                    )
                     latent_tile_num_frames = latent_chunk.shape[2]
 
                     if start_index > 0:
@@ -981,12 +979,14 @@ def __call__(
                         total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
 
                         conditioning_mask = torch.zeros(
-                            (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
+                            (batch_size, total_latent_num_frames),
+                            dtype=torch.float32,
+                            device=device,
                         )
                         conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
                     else:
                         total_latent_num_frames = latent_tile_num_frames
-                    
+
                     latent_chunk = self._pack_latents(
                         latent_chunk,
                         self.transformer_spatial_patch_size,
@@ -1002,29 +1002,31 @@ def __call__(
                         patch_size=self.transformer_spatial_patch_size,
                         device=device,
                     )
-                    
+
                     if start_index > 0:
                         conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
                         conditioning_mask_model_input = (
                             torch.cat([conditioning_mask, conditioning_mask])
                             if self.do_classifier_free_guidance
                             else conditioning_mask
                         )
-                    
+
                     video_ids = self._scale_video_ids(
                         video_ids,
                         scale_factor=self.vae_spatial_compression_ratio,
                         scale_factor_t=self.vae_temporal_compression_ratio,
                         frame_index=0,
-                        device=device
+                        device=device,
                     )
                     video_ids = video_ids.float()
                     video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
                     if self.do_classifier_free_guidance:
                         video_ids = torch.cat([video_ids, video_ids], dim=0)
 
                     # Set timesteps
-                    inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+                    inner_timesteps, inner_num_inference_steps = retrieve_timesteps(
+                        self.scheduler, num_inference_steps, device, timesteps
+                    )
                     sigmas = self.scheduler.sigmas
                     num_warmup_steps = max(len(inner_timesteps) - inner_num_inference_steps * self.scheduler.order, 0)
                     self._num_timesteps = len(inner_timesteps)
@@ -1035,7 +1037,9 @@ def __call__(
                                 continue
 
                             self._current_timestep = t
-                            latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                            latent_model_input = (
+                                torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                            )
                             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
                             timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                             if start_index > 0:
@@ -1054,7 +1058,9 @@ def __call__(
 
                             if self.do_classifier_free_guidance:
                                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                                noise_pred = noise_pred_uncond + self.guidance_scale * (
+                                    noise_pred_text - noise_pred_uncond
+                                )
                                 timestep, _ = timestep.chunk(2)
 
                                 if self.guidance_rescale > 0:
@@ -1082,7 +1088,9 @@ def __call__(
                                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
                             # call the callback, if provided
-                            if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                            if i == len(inner_timesteps) - 1 or (
+                                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                            ):
                                 progress_bar.update()
 
                             if XLA_AVAILABLE:
@@ -1096,13 +1104,15 @@ def __call__(
                         self.transformer_spatial_patch_size,
                         self.transformer_temporal_patch_size,
                     )
-                    
+
                     if start_index == 0:
                         first_tile_out_latents = latent_chunk.clone()
                     else:
                         # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-                        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1:, :, :]
-                        latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
+                        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
+                        latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
+                            latent_chunk, first_tile_out_latents, adain_factor
+                        )
 
                         alpha = torch.linspace(1, 0, temporal_overlap + 1, device=latent_chunk.device)[1:-1]
                         alpha = alpha.view(1, 1, -1, 1, 1)
@@ -1111,14 +1121,17 @@ def __call__(
                         t_minus_one = temporal_overlap - 1
                         parts = [
                             tile_out_latents[:, :, :-t_minus_one],
-                            alpha * tile_out_latents[:, :, -t_minus_one:] + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
+                            alpha * tile_out_latents[:, :, -t_minus_one:]
+                            + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
                             latent_chunk[:, :, t_minus_one:],
                         ]
                         latent_chunk = torch.cat(parts, dim=2)
 
                     tile_out_latents = latent_chunk.clone()
 
-                tile_weights = self._create_spatial_weights(tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap)
+                tile_weights = self._create_spatial_weights(
+                    tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
+                )
                 final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
                 weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
 