Skip to content

Commit e981399

Browse files
committed
simplify
1 parent 27a4345 commit e981399

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,8 @@ def __call__(
856856

857857
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
858858
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
859+
if horizontal_tiles > 1 or vertical_tiles > 1:
860+
raise ValueError("Setting `horizontal_tiles` or `vertical_tiles` to a value greater than 1 is not supported yet.")
859861

860862
# 1. Check inputs. Raise error if not correct
861863
self.check_inputs(
@@ -1175,25 +1177,14 @@ def __call__(
11751177
latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
11761178

11771179
alpha = torch.linspace(1, 0, temporal_overlap + 1, device=latent_chunk.device)[1:-1]
1178-
shape = [1] * latent_chunk.dim()
1179-
shape[2] = alpha.size(0)
1180-
alpha = alpha.reshape(shape)
1181-
1182-
slice_all = [slice(None)] * latent_chunk.dim()
1183-
slice_overlap1 = slice_all.copy()
1184-
slice_overlap1[2] = slice(-(temporal_overlap - 1), None)
1185-
slice_overlap2 = slice_all.copy()
1186-
slice_overlap2[2] = slice(0, temporal_overlap - 1)
1187-
slice_rest1 = slice_all.copy()
1188-
slice_rest1[2] = slice(None, -(temporal_overlap - 1))
1189-
slice_rest2 = slice_all.copy()
1190-
slice_rest2[2] = slice(temporal_overlap - 1, None)
1180+
alpha = alpha.view(1, 1, -1, 1, 1)
11911181

11921182
# Combine samples
1183+
t_minus_one = temporal_overlap - 1
11931184
parts = [
1194-
tile_out_latents[tuple(slice_rest1)],
1195-
alpha * tile_out_latents[tuple(slice_overlap1)] + (1 - alpha) * latent_chunk[tuple(slice_overlap2)],
1196-
latent_chunk[tuple(slice_rest2)],
1185+
tile_out_latents[:, :, :-t_minus_one],
1186+
alpha * tile_out_latents[:, :, -t_minus_one:] + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
1187+
latent_chunk[:, :, t_minus_one:],
11971188
]
11981189
latent_chunk = torch.cat(parts, dim=2)
11991190

@@ -1205,6 +1196,7 @@ def __call__(
12051196

12061197
eps = 1e-8
12071198
latents = final_latents / (weights + eps)
1199+
latents = LTXLatentUpsamplePipeline.tone_map_latents(latents, tone_map_compression_ratio)
12081200

12091201
if output_type == "latent":
12101202
video = latents

src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor,
122122
result = torch.lerp(latents, result, factor)
123123
return result
124124

125-
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
125+
@staticmethod
126+
def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
126127
"""
127128
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
128129
smooth way using a sigmoid-based compression.

0 commit comments

Comments
 (0)