@@ -974,9 +974,11 @@ def __call__(
974974                    latent_tile_num_frames  =  latent_chunk .shape [2 ]
975975
976976                    if  start_index  >  0 :
977-                         last_latent_chunk  =  self ._select_latents (tile_out_latents , - temporal_overlap , - 1 )
977+                         # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1) 
978+                         last_latent_chunk  =  self ._select_latents (tile_out_latents , 0 , temporal_overlap  -  1 )
979+                         last_latent_chunk  =  torch .flip (last_latent_chunk , dims = [2 ])
978980                        last_latent_tile_num_frames  =  last_latent_chunk .shape [2 ]
979-                         latent_chunk  =  torch .cat ([last_latent_chunk ,  latent_chunk ], dim = 2 )
981+                         latent_chunk  =  torch .cat ([latent_chunk ,  last_latent_chunk ], dim = 2 )
980982                        total_latent_num_frames  =  last_latent_tile_num_frames  +  latent_tile_num_frames 
981983                        last_latent_chunk  =  self ._pack_latents (
982984                            last_latent_chunk ,
@@ -993,7 +995,9 @@ def __call__(
993995                            device = device ,
994996                        )
995997                        # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength 
996-                         conditioning_mask [:, :last_latent_tile_num_frames ] =  1.0 
998+                         # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0 
999+                         conditioning_mask [:, - last_latent_tile_num_frames :] =  temporal_overlap_cond_strength 
1000+                         # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0 
9971001                    else :
9981002                        total_latent_num_frames  =  latent_tile_num_frames 
9991003
@@ -1051,14 +1055,14 @@ def __call__(
10511055                                torch .cat ([latent_chunk ] *  2 ) if  self .do_classifier_free_guidance  else  latent_chunk 
10521056                            )
10531057                            latent_model_input  =  latent_model_input .to (prompt_embeds .dtype )
1054-                              # Create timestep tensor that has prod(latent_model_input.shape) elements 
1058+ 
10551059                            if  start_index  ==  0 :
10561060                                timestep  =  t .expand (latent_model_input .shape [0 ]).unsqueeze (- 1 )
10571061                            else :
10581062                                timestep  =  t .view (1 , 1 ).expand ((latent_model_input .shape [:- 1 ])).clone ()
1059-                                 timestep [:, :last_latent_chunk_num_tokens ] =  0.0 
1060- 
1063+                                 timestep [:, - last_latent_chunk_num_tokens :] =  0.0 
10611064                            timestep  =  timestep .float ()
1065+ 
10621066                            # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() 
10631067                            # if start_index > 0: 
10641068                            #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) 
@@ -1094,7 +1098,8 @@ def __call__(
10941098                                latent_chunk  =  denoised_latent_chunk 
10951099                            else :
10961100                                latent_chunk  =  torch .cat (
1097-                                     [last_latent_chunk , denoised_latent_chunk [:, last_latent_chunk_num_tokens :]], dim = 1 
1101+                                     [denoised_latent_chunk [:, :- last_latent_chunk_num_tokens ], last_latent_chunk ],
1102+                                     dim = 1 ,
10981103                                )
10991104                                # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) 
11001105                                # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk) 
@@ -1129,7 +1134,7 @@ def __call__(
11291134                    if  start_index  ==  0 :
11301135                        first_tile_out_latents  =  latent_chunk .clone ()
11311136                    else :
1132-                         latent_chunk  =  latent_chunk [:, :, last_latent_tile_num_frames : - 1 , :, :]
1137+                         latent_chunk  =  latent_chunk [:, :, 1 : - last_latent_tile_num_frames , :, :]
11331138                        latent_chunk  =  LTXLatentUpsamplePipeline .adain_filter_latent (
11341139                            latent_chunk , first_tile_out_latents , adain_factor 
11351140                        )
@@ -1140,10 +1145,10 @@ def __call__(
11401145                        # Combine samples 
11411146                        t_minus_one  =  temporal_overlap  -  1 
11421147                        parts  =  [
1143-                             tile_out_latents [:, :, :- t_minus_one ],
1144-                             alpha  *  tile_out_latents [:, :, - t_minus_one :]
1145-                             +  ( 1   -   alpha )  *  latent_chunk [:, :, :t_minus_one ],
1146-                             latent_chunk [:, :, t_minus_one :],
1148+                             latent_chunk [:, :, :- t_minus_one ],
1149+                             ( 1   -   alpha )  *  latent_chunk [:, :, - t_minus_one :]
1150+                             +  alpha  *  tile_out_latents [:, :, :t_minus_one ],
1151+                             tile_out_latents [:, :, t_minus_one :],
11471152                        ]
11481153                        latent_chunk  =  torch .cat (parts , dim = 2 )
11491154
@@ -1152,7 +1157,7 @@ def __call__(
11521157                tile_weights  =  self ._create_spatial_weights (
11531158                    tile_out_latents , v , h , horizontal_tiles , vertical_tiles , spatial_overlap 
11541159                )
1155-                 final_latents [:, :, :, v_start :v_end , h_start :h_end ] +=  latent_chunk  *  tile_weights 
1160+                 final_latents [:, :, :, v_start :v_end , h_start :h_end ] +=  tile_out_latents  *  tile_weights 
11561161                weights [:, :, :, v_start :v_end , h_start :h_end ] +=  tile_weights 
11571162
11581163        eps  =  1e-8 
0 commit comments