
Commit 47bf390

refactor 1
1 parent e981399 commit 47bf390

File tree

1 file changed (+24 / -31 lines)


src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 24 additions & 31 deletions
@@ -973,6 +973,30 @@ def __call__(
             ):
                 latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
                 latent_tile_num_frames = latent_chunk.shape[2]
+
+                if start_index > 0:
+                    last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
+                    last_latent_tile_num_frames = last_latent_chunk.shape[2]
+                    latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
+                    total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+                else:
+                    total_latent_num_frames = latent_tile_num_frames
+
+                latent_chunk = self._pack_latents(
+                    latent_chunk,
+                    self.transformer_spatial_patch_size,
+                    self.transformer_temporal_patch_size,
+                )
+
+                video_ids = self._prepare_video_ids(
+                    batch_size,
+                    total_latent_num_frames,
+                    latent_tile_height,
+                    latent_tile_width,
+                    patch_size_t=self.transformer_temporal_patch_size,
+                    patch_size=self.transformer_spatial_patch_size,
+                    device=device,
+                )

                 # Set timesteps
                 inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -981,17 +1005,6 @@ def __call__(
                 self._num_timesteps = len(inner_timesteps)

                 if start_index == 0:
-                    latent_chunk = self._pack_latents(latent_chunk, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
-
-                    video_ids = self._prepare_video_ids(
-                        batch_size,
-                        latent_tile_num_frames,
-                        latent_tile_height,
-                        latent_tile_width,
-                        patch_size_t=self.transformer_temporal_patch_size,
-                        patch_size=self.transformer_spatial_patch_size,
-                        device=device,
-                    )
                     video_ids = self._scale_video_ids(
                         video_ids,
                         scale_factor=self.vae_spatial_compression_ratio,
@@ -1066,26 +1079,6 @@ def __call__(
                     )
                     first_tile_out_latents = tile_out_latents.clone()
                 else:
-                    last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
-                    last_latent_tile_num_frames = last_latent_chunk.shape[2]
-                    latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
-                    total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
-                    latent_chunk = self._pack_latents(
-                        latent_chunk,
-                        self.transformer_spatial_patch_size,
-                        self.transformer_temporal_patch_size,
-                    )
-
-                    video_ids = self._prepare_video_ids(
-                        batch_size,
-                        total_latent_num_frames,
-                        latent_tile_height,
-                        latent_tile_width,
-                        patch_size_t=self.transformer_temporal_patch_size,
-                        patch_size=self.transformer_spatial_patch_size,
-                        device=device,
-                    )
-
                     conditioning_mask = torch.zeros(
                         (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
                     )
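
Note: the change above is a deduplication. Both `start_index` branches previously ended with near-identical `_pack_latents` and `_prepare_video_ids` calls that differed only in the frame count passed in, so the refactor computes `total_latent_num_frames` inside the branch and hoists the two shared calls out of it. The sketch below illustrates that consolidated control flow in isolation; it is a minimal approximation, not the pipeline's actual code. `select_latents`, `pack_latents`, `prepare_video_ids`, and the `prepare_chunk` wrapper are simplified hypothetical stand-ins for the pipeline's private `_select_latents`, `_pack_latents`, and `_prepare_video_ids`, and the tensor shapes and packing layout are illustrative only.

import torch


def select_latents(latents: torch.Tensor, start: int, end: int) -> torch.Tensor:
    # Inclusive slice along the temporal axis (dim=2); negative indices wrap.
    num_frames = latents.shape[2]
    start = start % num_frames
    end = end % num_frames
    return latents[:, :, start : end + 1]


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # Flatten (B, C, F, H, W) -> (B, F*H*W, C); a stand-in for patch packing.
    batch_size, channels = latents.shape[:2]
    return latents.reshape(batch_size, channels, -1).transpose(1, 2)


def prepare_video_ids(batch_size: int, num_frames: int, height: int, width: int) -> torch.Tensor:
    # One (frame, row, col) coordinate triple per packed latent position.
    grid = torch.stack(
        torch.meshgrid(
            torch.arange(num_frames), torch.arange(height), torch.arange(width), indexing="ij"
        ),
        dim=-1,
    )
    return grid.reshape(1, -1, 3).repeat(batch_size, 1, 1)


def prepare_chunk(tile_latents, tile_out_latents, start_index, end_index, temporal_overlap, batch_size):
    latent_chunk = select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
    latent_tile_num_frames = latent_chunk.shape[2]

    # The branch now only decides whether to prepend the temporal-overlap frames
    # from the previous tile and what the total frame count is.
    if start_index > 0:
        last_latent_chunk = select_latents(tile_out_latents, -temporal_overlap, -1)
        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
        total_latent_num_frames = last_latent_chunk.shape[2] + latent_tile_num_frames
    else:
        total_latent_num_frames = latent_tile_num_frames

    # Packing and video-id preparation run once, shared by both paths.
    packed = pack_latents(latent_chunk)
    video_ids = prepare_video_ids(batch_size, total_latent_num_frames, latent_chunk.shape[3], latent_chunk.shape[4])
    return packed, video_ids, total_latent_num_frames


if __name__ == "__main__":
    # Toy tensors: batch=1, 4 channels, 16 latent frames, 8x8 latent tile.
    tile_latents = torch.randn(1, 4, 16, 8, 8)
    tile_out_latents = torch.randn(1, 4, 8, 8, 8)
    packed, video_ids, n = prepare_chunk(
        tile_latents, tile_out_latents, start_index=8, end_index=16, temporal_overlap=2, batch_size=1
    )
    print(packed.shape, video_ids.shape, n)  # torch.Size([1, 640, 4]) torch.Size([1, 640, 3]) 10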
