Skip to content

Commit 26b694b

Browse files
yiyixuxusayakpaulyiyixuxu
committed
[scheduler] fix a bug in add_noise (#7386)
* fix * fix * add a tests * fix --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 84bc0e4 commit 26b694b

17 files changed

+84
-10
lines changed

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,15 +528,12 @@ def check_inputs(
528528
f" {negative_prompt_embeds.shape}."
529529
)
530530

531-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
532531
def get_timesteps(self, num_inference_steps, strength, device):
533532
# get the original timestep using init_timestep
534533
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
535534

536535
t_start = max(num_inference_steps - init_timestep, 0)
537536
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
538-
if hasattr(self.scheduler, "set_begin_index"):
539-
self.scheduler.set_begin_index(t_start * self.scheduler.order)
540537

541538
return timesteps, num_inference_steps - t_start
542539

src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -716,15 +716,12 @@ def check_source_inputs(
716716
f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
717717
)
718718

719-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
720719
def get_timesteps(self, num_inference_steps, strength, device):
721720
# get the original timestep using init_timestep
722721
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
723722

724723
t_start = max(num_inference_steps - init_timestep, 0)
725724
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
726-
if hasattr(self.scheduler, "set_begin_index"):
727-
self.scheduler.set_begin_index(t_start * self.scheduler.order)
728725

729726
return timesteps, num_inference_steps - t_start
730727

src/diffusers/schedulers/scheduling_consistency_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,11 @@ def add_noise(
434434
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
435435
if self.begin_index is None:
436436
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
437+
elif self.step_index is not None:
438+
# add_noise is called after first denoising step (for inpainting)
439+
step_indices = [self.step_index] * timesteps.shape[0]
437440
else:
441+
# add noise is called bevore first denoising step to create inital latent(img2img)
438442
step_indices = [self.begin_index] * timesteps.shape[0]
439443

440444
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,10 +768,14 @@ def add_noise(
768768
schedule_timesteps = self.timesteps.to(original_samples.device)
769769
timesteps = timesteps.to(original_samples.device)
770770

771-
# begin_index is None when the scheduler is used for training
771+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
772772
if self.begin_index is None:
773773
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
774+
elif self.step_index is not None:
775+
# add_noise is called after first denoising step (for inpainting)
776+
step_indices = [self.step_index] * timesteps.shape[0]
774777
else:
778+
# add noise is called bevore first denoising step to create inital latent(img2img)
775779
step_indices = [self.begin_index] * timesteps.shape[0]
776780

777781
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,10 +1011,14 @@ def add_noise(
10111011
schedule_timesteps = self.timesteps.to(original_samples.device)
10121012
timesteps = timesteps.to(original_samples.device)
10131013

1014-
# begin_index is None when the scheduler is used for training
1014+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
10151015
if self.begin_index is None:
10161016
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
1017+
elif self.step_index is not None:
1018+
# add_noise is called after first denoising step (for inpainting)
1019+
step_indices = [self.step_index] * timesteps.shape[0]
10171020
else:
1021+
# add noise is called bevore first denoising step to create inital latent(img2img)
10181022
step_indices = [self.begin_index] * timesteps.shape[0]
10191023

10201024
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,11 @@ def add_noise(
543543
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
544544
if self.begin_index is None:
545545
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
546+
elif self.step_index is not None:
547+
# add_noise is called after first denoising step (for inpainting)
548+
step_indices = [self.step_index] * timesteps.shape[0]
546549
else:
550+
# add noise is called bevore first denoising step to create inital latent(img2img)
547551
step_indices = [self.begin_index] * timesteps.shape[0]
548552

549553
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,10 +961,14 @@ def add_noise(
961961
schedule_timesteps = self.timesteps.to(original_samples.device)
962962
timesteps = timesteps.to(original_samples.device)
963963

964-
# begin_index is None when the scheduler is used for training
964+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
965965
if self.begin_index is None:
966966
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
967+
elif self.step_index is not None:
968+
# add_noise is called after first denoising step (for inpainting)
969+
step_indices = [self.step_index] * timesteps.shape[0]
967970
else:
971+
# add noise is called bevore first denoising step to create inital latent(img2img)
968972
step_indices = [self.begin_index] * timesteps.shape[0]
969973

970974
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,11 @@ def add_noise(
669669
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
670670
if self.begin_index is None:
671671
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
672+
elif self.step_index is not None:
673+
# add_noise is called after first denoising step (for inpainting)
674+
step_indices = [self.step_index] * timesteps.shape[0]
672675
else:
676+
# add noise is called bevore first denoising step to create inital latent(img2img)
673677
step_indices = [self.begin_index] * timesteps.shape[0]
674678

675679
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,11 @@ def add_noise(
367367
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
368368
if self.begin_index is None:
369369
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
370+
elif self.step_index is not None:
371+
# add_noise is called after first denoising step (for inpainting)
372+
step_indices = [self.step_index] * timesteps.shape[0]
370373
else:
374+
# add noise is called bevore first denoising step to create inital latent(img2img)
371375
step_indices = [self.begin_index] * timesteps.shape[0]
372376

373377
sigma = sigmas[step_indices].flatten()

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,11 @@ def add_noise(
467467
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
468468
if self.begin_index is None:
469469
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
470+
elif self.step_index is not None:
471+
# add_noise is called after first denoising step (for inpainting)
472+
step_indices = [self.step_index] * timesteps.shape[0]
470473
else:
474+
# add noise is called bevore first denoising step to create inital latent(img2img)
471475
step_indices = [self.begin_index] * timesteps.shape[0]
472476

473477
sigma = sigmas[step_indices].flatten()

0 commit comments

Comments
 (0)