@@ -749,6 +749,7 @@ def __call__(
         max_sequence_length: int = 256,
         temporal_tile_size: int = 80,
         temporal_overlap: int = 24,
+        temporal_overlap_cond_strength: float = 0.5,
         horizontal_tiles: int = 1,
         vertical_tiles: int = 1,
         spatial_overlap: int = 1,
@@ -977,12 +978,21 @@ def __call__(
                         last_latent_tile_num_frames = last_latent_chunk.shape[2]
                         latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
                         total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+                        last_latent_chunk = self._pack_latents(
+                            last_latent_chunk,
+                            self.transformer_spatial_patch_size,
+                            self.transformer_temporal_patch_size,
+                        )
+                        last_latent_chunk_num_tokens = last_latent_chunk.shape[1]
+                        if self.do_classifier_free_guidance:
+                            last_latent_chunk = torch.cat([last_latent_chunk, last_latent_chunk], dim=0)

                         conditioning_mask = torch.zeros(
                             (batch_size, total_latent_num_frames),
                             dtype=torch.float32,
                             device=device,
                         )
+                        # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
                         conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
                     else:
                         total_latent_num_frames = latent_tile_num_frames
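In words: when a temporal tile overlaps the previous one, the already-denoised overlap frames are prepended to the new tile's latents, and a packed (tokenized) copy of those frames is kept so they can later be written back verbatim; the conditioning mask still marks those frames as fully conditioned. Below is a minimal standalone sketch of this bookkeeping, assuming a simplified `pack_latents` that mirrors the pipeline's `_pack_latents` reshape; all shapes and names are illustrative, not the pipeline's.

```python
# Minimal sketch of the overlap bookkeeping; pack_latents stands in for the
# pipeline's _pack_latents (patchify (B, C, F, H, W) -> (B, num_tokens, features)).
import torch


def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    b, c, f, h, w = latents.shape
    latents = latents.reshape(
        b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size
    )
    return latents.permute(0, 2, 4, 6, 3, 5, 7, 1).flatten(4, 7).flatten(1, 3)


batch_size, channels, height, width = 1, 4, 8, 8
overlap_frames, new_frames = 3, 10

last_latent_chunk = torch.randn(batch_size, channels, overlap_frames, height, width)  # denoised overlap
latent_chunk = torch.randn(batch_size, channels, new_frames, height, width)           # fresh noise

# Prepend the overlap so the transformer sees context from the previous tile.
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)

# Keep a tokenized copy of the overlap; its token count marks which positions stay frozen.
packed_overlap = pack_latents(last_latent_chunk)
overlap_num_tokens = packed_overlap.shape[1]

# 1.0 = fully conditioned frame, 0.0 = frame to denoise.
conditioning_mask = torch.zeros(batch_size, overlap_frames + new_frames)
conditioning_mask[:, :overlap_frames] = 1.0

print(overlap_num_tokens, conditioning_mask.tolist())
```

In the diff itself the packed overlap is also duplicated along the batch dimension when classifier-free guidance is enabled; that detail is omitted from the sketch.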
@@ -1041,9 +1051,17 @@ def __call__(
                                 torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
                             )
                             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                            timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                            if start_index > 0:
-                                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                            # Create a per-token timestep tensor; tokens carried over from the previous tile get timestep 0
+                            if start_index == 0:
+                                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+                            else:
+                                timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
+                                timestep[:, :last_latent_chunk_num_tokens] = 0.0
+
+                            timestep = timestep.float()
+                            # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                            # if start_index > 0:
+                            #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                             with self.transformer.cache_context("cond_uncond"):
                                 noise_pred = self.transformer(
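The hunk above replaces the old per-frame timestep clamping with an explicit per-token timestep: positions that belong to the previous tile get timestep 0, so the model treats them as already clean. A rough standalone illustration with invented shapes and names (no pipeline internals):

```python
# Sketch of the per-token timestep construction; shapes and names are illustrative only.
import torch

t = torch.tensor(800.0)                         # current scheduler timestep
latent_model_input = torch.randn(2, 320, 128)   # (batch, num_tokens, features)
overlap_num_tokens = 192                        # tokens copied from the previous tile
is_first_tile = False

if is_first_tile:
    # One timestep per sample, broadcast over every token.
    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
else:
    # One timestep per token, zeroed for the frozen overlap tokens.
    timestep = t.view(1, 1).expand(latent_model_input.shape[:-1]).clone()
    timestep[:, :overlap_num_tokens] = 0.0

timestep = timestep.float()
print(timestep.shape)                           # torch.Size([2, 320])
print(timestep[0, overlap_num_tokens - 1].item(), timestep[0, overlap_num_tokens].item())  # 0.0 800.0
```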
@@ -1075,8 +1093,11 @@ def __call__(
                             if start_index == 0:
                                 latent_chunk = denoised_latent_chunk
                             else:
-                                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                                latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
+                                latent_chunk = torch.cat(
+                                    [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
+                                )
+                                # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                                # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)

                             if callback_on_step_end is not None:
                                 callback_kwargs = {}
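Instead of the old timestep-threshold mask, each denoising step now simply overwrites the leading overlap token positions with the frozen tokens from the previous tile, keeping the overlap bit-exact across steps. A hedged sketch of that re-injection (single batch; the CFG duplication from the diff is omitted for clarity):

```python
# Sketch of re-injecting the frozen overlap tokens after a scheduler step.
import torch

overlap_num_tokens = 192
packed_overlap = torch.randn(1, overlap_num_tokens, 128)   # tokens from the previous tile
denoised_latent_chunk = torch.randn(1, 512, 128)           # scheduler output for this step

latent_chunk = torch.cat(
    [packed_overlap, denoised_latent_chunk[:, overlap_num_tokens:]], dim=1
)

assert torch.equal(latent_chunk[:, :overlap_num_tokens], packed_overlap)
assert latent_chunk.shape == denoised_latent_chunk.shape
```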
@@ -1108,8 +1129,7 @@ def __call__(
                     if start_index == 0:
                         first_tile_out_latents = latent_chunk.clone()
                     else:
-                        # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-                        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
+                        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
                         latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
                             latent_chunk, first_tile_out_latents, adain_factor
                         )
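The last hunk changes which frames are trimmed before AdaIN filtering: the old slice dropped the overlap plus one extra leading frame, while the new slice keeps the frame right after the overlap and drops the final frame instead. A small shape check with invented frame counts:

```python
# Toy comparison of the old and new trims (13 total latent frames, 3-frame overlap).
import torch

latent_chunk = torch.randn(1, 4, 13, 8, 8)   # (B, C, F, H, W)
last_latent_tile_num_frames = 3

old_trim = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]   # keeps frames 4..12
new_trim = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]      # keeps frames 3..11

print(old_trim.shape[2], new_trim.shape[2])   # 9 9 -> same length, shifted by one frame
```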