Commit 83bda52

Revert "try generating in reverse like... like what seems to be done in original codebase"
This reverts commit 0e97669.
1 parent 0e97669 commit 83bda52

File tree

1 file changed: +13 -18 lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 13 additions & 18 deletions
@@ -974,11 +974,9 @@ def __call__(
     latent_tile_num_frames = latent_chunk.shape[2]

     if start_index > 0:
-        # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
-        last_latent_chunk = self._select_latents(tile_out_latents, 0, temporal_overlap - 1)
-        last_latent_chunk = torch.flip(last_latent_chunk, dims=[2])
+        last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
         last_latent_tile_num_frames = last_latent_chunk.shape[2]
-        latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
+        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
         total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
         last_latent_chunk = self._pack_latents(
             last_latent_chunk,
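
This hunk restores forward-order temporal conditioning: the overlap latents are again taken from the end of the previous chunk's output and prepended to the new chunk, instead of being taken from its start, flipped along time, and appended. A minimal sketch of the two orderings, with _select_latents approximated by plain slicing (an assumption about its exact index semantics):

import torch

# Latents are [batch, channels, frames, height, width]; sizes are illustrative.
temporal_overlap = 4
tile_out_latents = torch.randn(1, 128, 16, 4, 4)  # previous chunk's denoised latents
latent_chunk = torch.randn(1, 128, 16, 4, 4)      # fresh latents for the next chunk

# After the revert: condition on the *last* overlap frames of the previous
# chunk and prepend them, so the new chunk continues forward in time.
last_latent_chunk = tile_out_latents[:, :, -temporal_overlap:]
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)

# The reverted experiment instead took the *first* overlap frames, flipped
# them along time, and appended them, generating in reverse:
# first = tile_out_latents[:, :, :temporal_overlap]
# latent_chunk = torch.cat([latent_chunk, torch.flip(first, dims=[2])], dim=2)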
@@ -995,9 +993,7 @@ def __call__(
             device=device,
         )
         # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
-        # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
-        conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
-        # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0
+        conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
     else:
         total_latent_num_frames = latent_tile_num_frames

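With the overlap prepended, the conditioning mask pins the leading frames rather than the trailing ones. A small sketch under assumed shapes:

import torch

# The prepended overlap frames sit at the front of the sequence, so the first
# last_latent_tile_num_frames entries are fully conditioned (1.0); the
# reverse-order variant marked the trailing entries instead.
batch_size, total_latent_num_frames = 1, 20
last_latent_tile_num_frames = 4
conditioning_mask = torch.zeros(batch_size, total_latent_num_frames)
conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
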
@@ -1055,14 +1051,14 @@ def __call__(
         torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
     )
     latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-
+    # Create timestep tensor that has prod(latent_model_input.shape) elements
     if start_index == 0:
         timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
     else:
         timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
-        timestep[:, -last_latent_chunk_num_tokens:] = 0.0
-        timestep = timestep.float()
+        timestep[:, :last_latent_chunk_num_tokens] = 0.0

+    timestep = timestep.float()
     # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
     # if start_index > 0:
     #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
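
The same front-of-sequence convention carries into the per-token timestep: every token receives the current scheduler timestep except the leading conditioning tokens, which are pinned to 0 so the model treats them as already-clean context. A sketch with illustrative shapes:

import torch

t = torch.tensor(400.0)                         # current scheduler timestep
latent_model_input = torch.randn(2, 1024, 128)  # [batch, tokens, channels] after packing
last_latent_chunk_num_tokens = 192              # tokens of the prepended overlap

timestep = t.view(1, 1).expand(latent_model_input.shape[:-1]).clone()
timestep[:, :last_latent_chunk_num_tokens] = 0.0  # conditioning tokens stay clean
timestep = timestep.float()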
@@ -1098,8 +1094,7 @@ def __call__(
         latent_chunk = denoised_latent_chunk
     else:
         latent_chunk = torch.cat(
-            [denoised_latent_chunk[:, :-last_latent_chunk_num_tokens], last_latent_chunk],
-            dim=1,
+            [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
         )
         # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
         # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
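
After each denoising step the conditioning tokens are spliced back in unchanged, now at the front of the token sequence. A sketch with assumed shapes:

import torch

denoised_latent_chunk = torch.randn(2, 1024, 128)  # model output over all tokens
last_latent_chunk = torch.randn(2, 192, 128)       # packed conditioning tokens
last_latent_chunk_num_tokens = last_latent_chunk.shape[1]

# Keep the conditioning tokens verbatim; only the trailing tokens evolve.
latent_chunk = torch.cat(
    [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
)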
@@ -1134,7 +1129,7 @@ def __call__(
     if start_index == 0:
         first_tile_out_latents = latent_chunk.clone()
     else:
-        latent_chunk = latent_chunk[:, :, 1:-last_latent_tile_num_frames, :, :]
+        latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
         latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
             latent_chunk, first_tile_out_latents, adain_factor
         )
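
Later chunks drop their prepended conditioning frames (and the final frame) before AdaIN-matching against the first tile's output. A sketch with a hand-rolled stand-in for LTXLatentUpsamplePipeline.adain_filter_latent, assuming it matches per-channel statistics to the reference and blends by adain_factor:

import torch

last_latent_tile_num_frames = 4
latent_chunk = torch.randn(1, 128, 21, 4, 4)
first_tile_out_latents = torch.randn(1, 128, 16, 4, 4)
adain_factor = 0.25

# Trim: keep only the newly generated frames.
latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]

# Stand-in AdaIN: normalize each (batch, channel) slice to the reference's
# mean/std, then lerp between original and normalized by adain_factor.
mean_c = latent_chunk.mean(dim=(2, 3, 4), keepdim=True)
std_c = latent_chunk.std(dim=(2, 3, 4), keepdim=True)
mean_r = first_tile_out_latents.mean(dim=(2, 3, 4), keepdim=True)
std_r = first_tile_out_latents.std(dim=(2, 3, 4), keepdim=True)
normalized = (latent_chunk - mean_c) / (std_c + 1e-6) * std_r + mean_r
latent_chunk = torch.lerp(latent_chunk, normalized, adain_factor)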
@@ -1145,10 +1140,10 @@ def __call__(
         # Combine samples
         t_minus_one = temporal_overlap - 1
         parts = [
-            latent_chunk[:, :, :-t_minus_one],
-            (1 - alpha) * latent_chunk[:, :, -t_minus_one:]
-            + alpha * tile_out_latents[:, :, :t_minus_one],
-            tile_out_latents[:, :, t_minus_one:],
+            tile_out_latents[:, :, :-t_minus_one],
+            alpha * tile_out_latents[:, :, -t_minus_one:]
+            + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
+            latent_chunk[:, :, t_minus_one:],
         ]
         latent_chunk = torch.cat(parts, dim=2)

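The combine step is flipped back to chronological order as well: the previous tile's frames lead, its trailing overlap crossfades into the new chunk's leading frames, and the rest of the new chunk follows. A sketch, with alpha assumed to be a per-frame blend ramp broadcastable over the latent shape:

import torch

temporal_overlap = 4
t_minus_one = temporal_overlap - 1
tile_out_latents = torch.randn(1, 128, 16, 4, 4)  # earlier frames
latent_chunk = torch.randn(1, 128, 16, 4, 4)      # newly generated frames
# Assumed: linear ramp from old to new across the overlap region.
alpha = torch.linspace(1.0, 0.0, t_minus_one).view(1, 1, -1, 1, 1)

parts = [
    tile_out_latents[:, :, :-t_minus_one],           # old frames, untouched
    alpha * tile_out_latents[:, :, -t_minus_one:]    # crossfaded overlap
    + (1 - alpha) * latent_chunk[:, :, :t_minus_one],
    latent_chunk[:, :, t_minus_one:],                # new frames, untouched
]
latent_chunk = torch.cat(parts, dim=2)               # 16 + 16 - 3 = 29 frames
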
@@ -1157,7 +1152,7 @@ def __call__(
     tile_weights = self._create_spatial_weights(
         tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
     )
-    final_latents[:, :, :, v_start:v_end, h_start:h_end] += tile_out_latents * tile_weights
+    final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
     weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights

     eps = 1e-8
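
The last change fixes which tensor lands in the output canvas: the temporally blended latent_chunk, not the pre-blend tile_out_latents. A sketch of the weighted tile accumulation, with _create_spatial_weights replaced by a stand-in and the eps normalization assumed from the surrounding context:

import torch

final_latents = torch.zeros(1, 128, 29, 32, 32)
weights = torch.zeros_like(final_latents)
v_start, v_end, h_start, h_end = 0, 16, 0, 16
latent_chunk = torch.randn(1, 128, 29, 16, 16)  # the blended tile
tile_weights = torch.ones(1, 1, 1, 16, 16)      # stand-in for _create_spatial_weights

final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights

# Assumed follow-up: overlapping contributions are normalized by total weight.
eps = 1e-8
final_latents = final_latents / (weights + eps)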
