
Commit 3de88be ("refactor 3")
Parent: 84c17be

File tree

1 file changed: +70 -125 lines changed


src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 70 additions & 125 deletions
@@ -1005,6 +1005,11 @@ def __call__(
 
         if start_index > 0:
             conditioning_mask = conditioning_mask.gather(1, video_ids[:, 0])
+            conditioning_mask_model_input = (
+                torch.cat([conditioning_mask, conditioning_mask])
+                if self.do_classifier_free_guidance
+                else conditioning_mask
+            )
 
         video_ids = self._scale_video_ids(
             video_ids,
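
Note: this hunk hoists `conditioning_mask_model_input` out of the per-step denoising loop, where the previous code rebuilt it on every iteration; the mask does not change across steps, so it can be computed once. The batch-dimension duplication keeps it aligned with the `[uncond, cond]` latent batch used for classifier-free guidance. A minimal sketch of the pattern with made-up shapes (the real mask comes from the pipeline's conditioning setup):

```python
import torch

# Toy stand-ins: batch of 1 with 8 latent tokens (shapes are illustrative only).
conditioning_mask = torch.zeros(1, 8)
conditioning_mask[:, :2] = 1.0  # pretend the first two tokens are conditioning tokens
do_classifier_free_guidance = True

# Computed once before the loop: mirror the mask across the batch dimension so it
# lines up with torch.cat([latent_chunk] * 2) used for the [uncond, cond] pass.
conditioning_mask_model_input = (
    torch.cat([conditioning_mask, conditioning_mask])
    if do_classifier_free_guidance
    else conditioning_mask
)

print(conditioning_mask_model_input.shape)  # torch.Size([2, 8])
```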
@@ -1024,137 +1029,77 @@ def __call__(
         num_warmup_steps = max(len(inner_timesteps) - inner_num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(inner_timesteps)
 
-        if start_index == 0:
-            with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
-                for i, t in enumerate(inner_timesteps):
-                    if self.interrupt:
-                        continue
-
-                    self._current_timestep = t
-                    latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
-                    latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-
-                    with self.transformer.cache_context("cond_uncond"):
-                        noise_pred = self.transformer(
-                            hidden_states=latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            encoder_attention_mask=prompt_attention_mask,
-                            video_coords=video_ids,
-                            attention_kwargs=attention_kwargs,
-                            return_dict=False,
-                        )[0]
-
-                    if self.do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                        timestep, _ = timestep.chunk(2)
-
-                    if self.guidance_rescale > 0:
-                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                        noise_pred = rescale_noise_cfg(
-                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
-                        )
-
-                    latent_chunk = self.scheduler.step(
-                        -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
-                    )[0]
+        with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
+            for i, t in enumerate(inner_timesteps):
+                if self.interrupt:
+                    continue
 
-                    if callback_on_step_end is not None:
-                        callback_kwargs = {}
-                        for k in callback_on_step_end_tensor_inputs:
-                            callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                        latent_chunk = callback_outputs.pop("latents", latent_chunk)
-                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                    # call the callback, if provided
-                    if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
-
-                    if XLA_AVAILABLE:
-                        xm.mark_step()
-
-            tile_out_latents = self._unpack_latents(
-                latent_chunk,
-                latent_tile_num_frames,
-                latent_tile_height,
-                latent_tile_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            first_tile_out_latents = tile_out_latents.clone()
-        else:
-            with self.progress_bar(total=inner_num_inference_steps) as progress_bar:
-                for i, t in enumerate(inner_timesteps):
-                    if self.interrupt:
-                        continue
-
-                    self._current_timestep = t
-                    latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
-                    latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                    conditioning_mask_model_input = (
-                        torch.cat([conditioning_mask, conditioning_mask])
-                        if self.do_classifier_free_guidance
-                        else conditioning_mask
-                    )
-                    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                self._current_timestep = t
+                latent_model_input = torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
+                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                if start_index > 0:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
-                    with self.transformer.cache_context("cond_uncond"):
-                        noise_pred = self.transformer(
-                            hidden_states=latent_model_input,
-                            encoder_hidden_states=prompt_embeds,
-                            timestep=timestep,
-                            encoder_attention_mask=prompt_attention_mask,
-                            video_coords=video_ids,
-                            attention_kwargs=attention_kwargs,
-                            return_dict=False,
-                        )[0]
-
-                    if self.do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                        timestep, _ = timestep.chunk(2)
-
-                    if self.guidance_rescale > 0:
-                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                        noise_pred = rescale_noise_cfg(
-                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
-                        )
-
-                    denoised_latent_chunk = self.scheduler.step(
-                        -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        video_coords=video_ids,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    timestep, _ = timestep.chunk(2)
+
+                if self.guidance_rescale > 0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                    )
+
+                denoised_latent_chunk = self.scheduler.step(
+                    -noise_pred, t, latent_chunk, per_token_timesteps=timestep, return_dict=False
+                )[0]
+                if start_index == 0:
+                    latent_chunk = denoised_latent_chunk
+                else:
                     tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
                     latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
 
-                    if callback_on_step_end is not None:
-                        callback_kwargs = {}
-                        for k in callback_on_step_end_tensor_inputs:
-                            callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                        latent_chunk = callback_outputs.pop("latents", latent_chunk)
-                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                    # call the callback, if provided
-                    if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
-
-                    if XLA_AVAILABLE:
-                        xm.mark_step()
-
-            latent_chunk = self._unpack_latents(
-                latent_chunk,
-                total_latent_num_frames,
-                latent_tile_height,
-                latent_tile_width,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latent_chunk = callback_outputs.pop("latents", latent_chunk)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(inner_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        latent_chunk = self._unpack_latents(
+            latent_chunk,
+            total_latent_num_frames,
+            latent_tile_height,
+            latent_tile_width,
+            self.transformer_spatial_patch_size,
+            self.transformer_temporal_patch_size,
+        )
+
+        if start_index == 0:
+            first_tile_out_latents = latent_chunk.clone()
+        else:
             # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
             latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1:, :, :]
             latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(latent_chunk, first_tile_out_latents, adain_factor)
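
Note: this hunk merges the two nearly identical denoising loops into one, branching on `start_index` only at the two points where they actually differed. For `start_index == 0` the scheduler output is taken as-is; for later chunks the conditioning mask both clamps the per-token timestep and gates which tokens the step may overwrite. A self-contained sketch of that gating, using toy tensors rather than the pipeline's packed latents:

```python
import torch

# Toy stand-ins (shapes illustrative): 1 sample, 8 tokens, 4 latent channels.
latent_chunk = torch.randn(1, 8, 4)
denoised_latent_chunk = torch.randn(1, 8, 4)
conditioning_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
t = torch.tensor(800.0)  # current scheduler timestep on a [0, 1000] scale

# Conditioning tokens (mask == 1) are clamped to timestep 0, i.e. treated as
# already clean; free tokens keep the scheduler's timestep.
timestep = t.expand(latent_chunk.shape[0]).unsqueeze(-1).float()
timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)

# Only tokens whose noise level is at or above the current step are overwritten
# with the denoised prediction; conditioning tokens pass through untouched.
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
```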
@@ -1171,7 +1116,7 @@ def __call__(
             ]
             latent_chunk = torch.cat(parts, dim=2)
 
-            tile_out_latents = latent_chunk.clone()
+        tile_out_latents = latent_chunk.clone()
 
         tile_weights = self._create_spatial_weights(tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap)
         final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
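
Note: the merged loop's guidance path calls `rescale_noise_cfg`, which in `diffusers` implements the std-matching rescale from Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf. A self-contained re-sketch of that helper and the surrounding CFG combine, with toy tensors in place of the pipeline's packed latents:

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Std-matching rescale (Sec. 3.4, https://arxiv.org/pdf/2305.08891.pdf):
    # shrink the guided prediction back toward the std of the text-conditional
    # branch, then blend with the unrescaled result to avoid over-flattening.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg

# Toy [uncond, cond] transformer output, combined the same way as in the loop.
guidance_scale, guidance_rescale = 5.0, 0.7
noise_pred = torch.randn(2, 8, 4)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if guidance_rescale > 0:
    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
```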
