@@ -974,11 +974,9 @@ def __call__(
974974                    latent_tile_num_frames  =  latent_chunk .shape [2 ]
975975
976976                    if  start_index  >  0 :
977-                         # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) 
978-                         last_latent_chunk  =  self ._select_latents (tile_out_latents , 0 , temporal_overlap  -  1 )
979-                         last_latent_chunk  =  torch .flip (last_latent_chunk , dims = [2 ])
977+                         last_latent_chunk  =  self ._select_latents (tile_out_latents , - temporal_overlap , - 1 )
980978                        last_latent_tile_num_frames  =  last_latent_chunk .shape [2 ]
981-                         latent_chunk  =  torch .cat ([latent_chunk ,  last_latent_chunk ], dim = 2 )
979+                         latent_chunk  =  torch .cat ([last_latent_chunk ,  latent_chunk ], dim = 2 )
982980                        total_latent_num_frames  =  last_latent_tile_num_frames  +  latent_tile_num_frames 
983981                        last_latent_chunk  =  self ._pack_latents (
984982                            last_latent_chunk ,
@@ -995,9 +993,7 @@ def __call__(
995993                            device = device ,
996994                        )
997995                        # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength 
998-                         # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 
999-                         conditioning_mask [:, - last_latent_tile_num_frames :] =  temporal_overlap_cond_strength 
1000-                         # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0 
996+                         conditioning_mask [:, :last_latent_tile_num_frames ] =  1.0 
1001997                    else :
1002998                        total_latent_num_frames  =  latent_tile_num_frames 
1003999
@@ -1055,14 +1051,14 @@ def __call__(
10551051                                torch .cat ([latent_chunk ] *  2 ) if  self .do_classifier_free_guidance  else  latent_chunk 
10561052                            )
10571053                            latent_model_input  =  latent_model_input .to (prompt_embeds .dtype )
1058- 
1054+                              # Create timestep tensor that has prod(latent_model_input.shape) elements 
10591055                            if  start_index  ==  0 :
10601056                                timestep  =  t .expand (latent_model_input .shape [0 ]).unsqueeze (- 1 )
10611057                            else :
10621058                                timestep  =  t .view (1 , 1 ).expand ((latent_model_input .shape [:- 1 ])).clone ()
1063-                                 timestep [:, - last_latent_chunk_num_tokens :] =  0.0 
1064-                             timestep  =  timestep .float ()
1059+                                 timestep [:, :last_latent_chunk_num_tokens ] =  0.0 
10651060
1061+                             timestep  =  timestep .float ()
10661062                            # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() 
10671063                            # if start_index > 0: 
10681064                            #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) 
@@ -1098,8 +1094,7 @@ def __call__(
10981094                                latent_chunk  =  denoised_latent_chunk 
10991095                            else :
11001096                                latent_chunk  =  torch .cat (
1101-                                     [denoised_latent_chunk [:, :- last_latent_chunk_num_tokens ], last_latent_chunk ],
1102-                                     dim = 1 ,
1097+                                     [last_latent_chunk , denoised_latent_chunk [:, last_latent_chunk_num_tokens :]], dim = 1 
11031098                                )
11041099                                # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) 
11051100                                # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk) 
@@ -1134,7 +1129,7 @@ def __call__(
11341129                    if  start_index  ==  0 :
11351130                        first_tile_out_latents  =  latent_chunk .clone ()
11361131                    else :
1137-                         latent_chunk  =  latent_chunk [:, :, 1 : - last_latent_tile_num_frames , :, :]
1132+                         latent_chunk  =  latent_chunk [:, :, last_latent_tile_num_frames : - 1 , :, :]
11381133                        latent_chunk  =  LTXLatentUpsamplePipeline .adain_filter_latent (
11391134                            latent_chunk , first_tile_out_latents , adain_factor 
11401135                        )
@@ -1145,10 +1140,10 @@ def __call__(
11451140                        # Combine samples 
11461141                        t_minus_one  =  temporal_overlap  -  1 
11471142                        parts  =  [
1148-                             latent_chunk [:, :, :- t_minus_one ],
1149-                             ( 1   -   alpha )  *  latent_chunk [:, :, - t_minus_one :]
1150-                             +  alpha  *  tile_out_latents [:, :, :t_minus_one ],
1151-                             tile_out_latents [:, :, t_minus_one :],
1143+                             tile_out_latents [:, :, :- t_minus_one ],
1144+                             alpha  *  tile_out_latents [:, :, - t_minus_one :]
1145+                             +  ( 1   -   alpha )  *  latent_chunk [:, :, :t_minus_one ],
1146+                             latent_chunk [:, :, t_minus_one :],
11521147                        ]
11531148                        latent_chunk  =  torch .cat (parts , dim = 2 )
11541149
@@ -1157,7 +1152,7 @@ def __call__(
11571152                tile_weights  =  self ._create_spatial_weights (
11581153                    tile_out_latents , v , h , horizontal_tiles , vertical_tiles , spatial_overlap 
11591154                )
1160-                 final_latents [:, :, :, v_start :v_end , h_start :h_end ] +=  tile_out_latents  *  tile_weights 
1155+                 final_latents [:, :, :, v_start :v_end , h_start :h_end ] +=  latent_chunk  *  tile_weights 
11611156                weights [:, :, :, v_start :v_end , h_start :h_end ] +=  tile_weights 
11621157
11631158        eps  =  1e-8 
0 commit comments