@@ -856,6 +856,8 @@ def __call__(
856856
857857 if isinstance (callback_on_step_end , (PipelineCallback , MultiPipelineCallbacks )):
858858 callback_on_step_end_tensor_inputs = callback_on_step_end .tensor_inputs
859+ if horizontal_tiles > 1 or vertical_tiles > 1 :
860+ raise ValueError ("Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 0 is not supported yet." )
859861
860862 # 1. Check inputs. Raise error if not correct
861863 self .check_inputs (
@@ -1175,25 +1177,14 @@ def __call__(
11751177 latent_chunk = LTXLatentUpsamplePipeline .adain_filter_latent (latent_chunk , first_tile_out_latents , adain_factor )
11761178
11771179 alpha = torch .linspace (1 , 0 , temporal_overlap + 1 , device = latent_chunk .device )[1 :- 1 ]
1178- shape = [1 ] * latent_chunk .dim ()
1179- shape [2 ] = alpha .size (0 )
1180- alpha = alpha .reshape (shape )
1181-
1182- slice_all = [slice (None )] * latent_chunk .dim ()
1183- slice_overlap1 = slice_all .copy ()
1184- slice_overlap1 [2 ] = slice (- (temporal_overlap - 1 ), None )
1185- slice_overlap2 = slice_all .copy ()
1186- slice_overlap2 [2 ] = slice (0 , temporal_overlap - 1 )
1187- slice_rest1 = slice_all .copy ()
1188- slice_rest1 [2 ] = slice (None , - (temporal_overlap - 1 ))
1189- slice_rest2 = slice_all .copy ()
1190- slice_rest2 [2 ] = slice (temporal_overlap - 1 , None )
1180+ alpha = alpha .view (1 , 1 , - 1 , 1 , 1 )
11911181
11921182 # Combine samples
1183+ t_minus_one = temporal_overlap - 1
11931184 parts = [
1194- tile_out_latents [tuple ( slice_rest1 ) ],
1195- alpha * tile_out_latents [tuple ( slice_overlap1 ) ] + (1 - alpha ) * latent_chunk [tuple ( slice_overlap2 ) ],
1196- latent_chunk [tuple ( slice_rest2 ) ],
1185+ tile_out_latents [:, :, : - t_minus_one ],
1186+ alpha * tile_out_latents [:, :, - t_minus_one : ] + (1 - alpha ) * latent_chunk [:, :, : t_minus_one ],
1187+ latent_chunk [:, :, t_minus_one : ],
11971188 ]
11981189 latent_chunk = torch .cat (parts , dim = 2 )
11991190
@@ -1205,6 +1196,7 @@ def __call__(
12051196
12061197 eps = 1e-8
12071198 latents = final_latents / (weights + eps )
1199+ latents = LTXLatentUpsamplePipeline .tone_map_latents (latents , tone_map_compression_ratio )
12081200
12091201 if output_type == "latent" :
12101202 video = latents
0 commit comments