Skip to content

Commit 565fa0c

Browse files
committed
fix cfg cpu gpu delay
1 parent b11cc60 commit 565fa0c

File tree

1 file changed

+6
-19
lines changed

1 file changed

+6
-19
lines changed

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 6 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -698,30 +698,17 @@ def __call__(
698698
# 6. Denoising loop
699699
with self.progress_bar(total=num_inference_steps) as progress_bar:
700700
for i, t in enumerate(timesteps):
701+
# compute whether apply classifier-free truncation on this timestep
702+
do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
703+
701704
# expand the latents if we are doing classifier free guidance
702705
latent_model_input = (
703706
torch.cat([latents] * 2)
704-
if do_classifier_free_guidance
705-
and 1 - t.item() / self.scheduler.config.num_train_timesteps < cfg_trunc_ratio
707+
if do_classifier_free_guidance and not do_classifier_free_truncation
706708
else latents
707709
)
710+
708711
current_timestep = t
709-
if not torch.is_tensor(current_timestep):
710-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
711-
# This would be a good case for the `match` statement (Python 3.10+)
712-
is_mps = latent_model_input.device.type == "mps"
713-
if isinstance(current_timestep, float):
714-
dtype = torch.float32 if is_mps else torch.float64
715-
else:
716-
dtype = torch.int32 if is_mps else torch.int64
717-
current_timestep = torch.tensor(
718-
[current_timestep],
719-
dtype=dtype,
720-
device=latent_model_input.device,
721-
)
722-
elif len(current_timestep.shape) == 0:
723-
current_timestep = current_timestep[None].to(latent_model_input.device)
724-
725712
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
726713
current_timestep = current_timestep.expand(latent_model_input.shape[0])
727714
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
@@ -736,7 +723,7 @@ def __call__(
736723
)[0]
737724

738725
# perform normalization-based guidance scale on a truncated timestep interval
739-
if do_classifier_free_guidance and current_timestep[0] < cfg_trunc_ratio:
726+
if self.do_classifier_free_guidance and not do_classifier_free_truncation:
740727
noise_pred_cond, noise_pred_uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
741728
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
742729
# apply normalization after classifier-free guidance

0 commit comments

Comments (0)