
Commit 84c17be

refactor 2
1 parent 47bf390 commit 84c17be

1 file changed: +20 -30 lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 20 additions & 30 deletions
@@ -979,6 +979,11 @@ def __call__(
                 last_latent_tile_num_frames = last_latent_chunk.shape[2]
                 latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
                 total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+
+                conditioning_mask = torch.zeros(
+                    (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
+                )
+                conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
             else:
                 total_latent_num_frames = latent_tile_num_frames
 
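For reference, the conditioning mask added above is a per-frame indicator over the concatenated latent chunk: ones mark the frames carried over from the previous tile, zeros mark the frames still to be generated. A minimal sketch of this construction, with illustrative values (the shapes below are assumptions, not taken from the pipeline):

import torch

batch_size = 2
last_latent_tile_num_frames = 3  # frames reused from the previous tile (illustrative)
latent_tile_num_frames = 5       # freshly generated frames in this tile (illustrative)
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames

# 1.0 marks conditioning frames, 0.0 marks frames left for the denoiser.
conditioning_mask = torch.zeros((batch_size, total_latent_num_frames), dtype=torch.float32)
conditioning_mask[:, :last_latent_tile_num_frames] = 1.0

print(conditioning_mask[0])  # tensor([1., 1., 1., 0., 0., 0., 0., 0.])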
@@ -998,25 +1003,28 @@ def __call__(
                 device=device,
             )
 
+            if start_index > 0:
+                conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
+
+            video_ids = self._scale_video_ids(
+                video_ids,
+                scale_factor=self.vae_spatial_compression_ratio,
+                scale_factor_t=self.vae_temporal_compression_ratio,
+                frame_index=0,
+                device=device
+            )
+            video_ids = video_ids.float()
+            video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
+            if self.do_classifier_free_guidance:
+                video_ids = torch.cat([video_ids, video_ids], dim=0)
+
             # Set timesteps
             inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
             sigmas = self.scheduler.sigmas
             num_warmup_steps = max(len(inner_timesteps) - inner_num_inference_steps * self.scheduler.order, 0)
             self._num_timesteps = len(inner_timesteps)
 
             if start_index == 0:
-                video_ids = self._scale_video_ids(
-                    video_ids,
-                    scale_factor=self.vae_spatial_compression_ratio,
-                    scale_factor_t=self.vae_temporal_compression_ratio,
-                    frame_index=0,
-                    device=device
-                )
-                video_ids = video_ids.float()
-                video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
-                if self.do_classifier_free_guidance:
-                    video_ids = torch.cat([video_ids, video_ids], dim=0)
-
                 with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
                     for i, t in enumerate(inner_timesteps):
                         if self.interrupt:
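The gather on new line 1007 lifts the per-frame mask to a per-token mask. Assuming video_ids[:, 0] holds each token's latent-frame index (an assumption about the token layout; the gather runs before the ids are scaled and cast to float), gathering along dim=1 replicates each frame's mask value across that frame's tokens. A minimal sketch with made-up shapes:

import torch

# (batch, frames): the first two latent frames are conditioning frames.
conditioning_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
# (batch, tokens): assumed per-token temporal index, standing in for video_ids[:, 0].
frame_index_per_token = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3]])

token_mask = conditioning_mask.gather(1, frame_index_per_token)
print(token_mask)  # tensor([[1., 1., 1., 1., 0., 0., 0., 0.]])

# Under classifier-free guidance the coordinate ids are duplicated along the
# batch dimension so they line up with the cond/uncond stacking of the latents.
video_ids = torch.randn(1, 3, 8)  # illustrative (batch, coord, token) layout
video_ids = torch.cat([video_ids, video_ids], dim=0)  # -> (2, 3, 8)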
@@ -1079,24 +1087,6 @@ def __call__(
                 )
                 first_tile_out_latents = tile_out_latents.clone()
             else:
-                conditioning_mask = torch.zeros(
-                    (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
-                )
-                conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
-                conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
-
-                video_ids = self._scale_video_ids(
-                    video_ids,
-                    scale_factor=self.vae_spatial_compression_ratio,
-                    scale_factor_t=self.vae_temporal_compression_ratio,
-                    frame_index=0,
-                    device=device
-                )
-                video_ids = video_ids.float()
-                video_ids[:, 0] = video_ids[:, 0] * (1.0 / frame_rate)
-                if self.do_classifier_free_guidance:
-                    video_ids = torch.cat([video_ids, video_ids], dim=0)
-
                 with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
                     for i, t in enumerate(inner_timesteps):
                         if self.interrupt:
