Skip to content

Commit 485b8bb

Browse files
authored
refactor get_timesteps for SDXL img2img + add set_begin_index (#9375)
* refator + add begin_index * add kolors img2img to doc
1 parent d08ad65 commit 485b8bb

File tree

7 files changed

+86
-60
lines changed

7 files changed

+86
-60
lines changed

docs/source/en/api/pipelines/kolors.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,11 @@ image.save("kolors_ipa_sample.png")
105105

106106
- all
107107
- __call__
108+
109+
## KolorsImg2ImgPipeline
110+
111+
[[autodoc]] KolorsImg2ImgPipeline
112+
113+
- all
114+
- __call__
115+

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,22 +1024,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
10241024
if denoising_start is None:
10251025
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
10261026
t_start = max(num_inference_steps - init_timestep, 0)
1027-
else:
1028-
t_start = 0
10291027

1030-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1028+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1029+
if hasattr(self.scheduler, "set_begin_index"):
1030+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
1031+
1032+
return timesteps, num_inference_steps - t_start
10311033

1032-
# Strength is irrelevant if we directly request a timestep to start at;
1033-
# that is, strength is determined by the denoising_start instead.
1034-
if denoising_start is not None:
1034+
else:
1035+
# Strength is irrelevant if we directly request a timestep to start at;
1036+
# that is, strength is determined by the denoising_start instead.
10351037
discrete_timestep_cutoff = int(
10361038
round(
10371039
self.scheduler.config.num_train_timesteps
10381040
- (denoising_start * self.scheduler.config.num_train_timesteps)
10391041
)
10401042
)
10411043

1042-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
1044+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
10431045
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
10441046
# if the scheduler is a 2nd order scheduler we might have to do +1
10451047
# because `num_inference_steps` might be even given that every timestep
@@ -1050,11 +1052,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
10501052
num_inference_steps = num_inference_steps + 1
10511053

10521054
# because t_n+1 >= t_n, we slice the timesteps starting from the end
1053-
timesteps = timesteps[-num_inference_steps:]
1055+
t_start = len(self.scheduler.timesteps) - num_inference_steps
1056+
timesteps = self.scheduler.timesteps[t_start:]
1057+
if hasattr(self.scheduler, "set_begin_index"):
1058+
self.scheduler.set_begin_index(t_start)
10541059
return timesteps, num_inference_steps
10551060

1056-
return timesteps, num_inference_steps - t_start
1057-
10581061
def _get_add_time_ids(
10591062
self,
10601063
original_size,

src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -564,22 +564,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
564564
if denoising_start is None:
565565
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
566566
t_start = max(num_inference_steps - init_timestep, 0)
567-
else:
568-
t_start = 0
569567

570-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
568+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
569+
if hasattr(self.scheduler, "set_begin_index"):
570+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
571+
572+
return timesteps, num_inference_steps - t_start
571573

572-
# Strength is irrelevant if we directly request a timestep to start at;
573-
# that is, strength is determined by the denoising_start instead.
574-
if denoising_start is not None:
574+
else:
575+
# Strength is irrelevant if we directly request a timestep to start at;
576+
# that is, strength is determined by the denoising_start instead.
575577
discrete_timestep_cutoff = int(
576578
round(
577579
self.scheduler.config.num_train_timesteps
578580
- (denoising_start * self.scheduler.config.num_train_timesteps)
579581
)
580582
)
581583

582-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
584+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
583585
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
584586
# if the scheduler is a 2nd order scheduler we might have to do +1
585587
# because `num_inference_steps` might be even given that every timestep
@@ -590,11 +592,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
590592
num_inference_steps = num_inference_steps + 1
591593

592594
# because t_n+1 >= t_n, we slice the timesteps starting from the end
593-
timesteps = timesteps[-num_inference_steps:]
595+
t_start = len(self.scheduler.timesteps) - num_inference_steps
596+
timesteps = self.scheduler.timesteps[t_start:]
597+
if hasattr(self.scheduler, "set_begin_index"):
598+
self.scheduler.set_begin_index(t_start)
594599
return timesteps, num_inference_steps
595600

596-
return timesteps, num_inference_steps - t_start
597-
598601
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
599602
def prepare_latents(
600603
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -648,22 +648,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
648648
if denoising_start is None:
649649
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
650650
t_start = max(num_inference_steps - init_timestep, 0)
651-
else:
652-
t_start = 0
653651

654-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
652+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
653+
if hasattr(self.scheduler, "set_begin_index"):
654+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
655+
656+
return timesteps, num_inference_steps - t_start
655657

656-
# Strength is irrelevant if we directly request a timestep to start at;
657-
# that is, strength is determined by the denoising_start instead.
658-
if denoising_start is not None:
658+
else:
659+
# Strength is irrelevant if we directly request a timestep to start at;
660+
# that is, strength is determined by the denoising_start instead.
659661
discrete_timestep_cutoff = int(
660662
round(
661663
self.scheduler.config.num_train_timesteps
662664
- (denoising_start * self.scheduler.config.num_train_timesteps)
663665
)
664666
)
665667

666-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
668+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
667669
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
668670
# if the scheduler is a 2nd order scheduler we might have to do +1
669671
# because `num_inference_steps` might be even given that every timestep
@@ -674,11 +676,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
674676
num_inference_steps = num_inference_steps + 1
675677

676678
# because t_n+1 >= t_n, we slice the timesteps starting from the end
677-
timesteps = timesteps[-num_inference_steps:]
679+
t_start = len(self.scheduler.timesteps) - num_inference_steps
680+
timesteps = self.scheduler.timesteps[t_start:]
681+
if hasattr(self.scheduler, "set_begin_index"):
682+
self.scheduler.set_begin_index(t_start)
678683
return timesteps, num_inference_steps
679684

680-
return timesteps, num_inference_steps - t_start
681-
682685
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
683686
def prepare_latents(
684687
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -897,22 +897,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
897897
if denoising_start is None:
898898
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
899899
t_start = max(num_inference_steps - init_timestep, 0)
900-
else:
901-
t_start = 0
902900

903-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
901+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
902+
if hasattr(self.scheduler, "set_begin_index"):
903+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
904+
905+
return timesteps, num_inference_steps - t_start
904906

905-
# Strength is irrelevant if we directly request a timestep to start at;
906-
# that is, strength is determined by the denoising_start instead.
907-
if denoising_start is not None:
907+
else:
908+
# Strength is irrelevant if we directly request a timestep to start at;
909+
# that is, strength is determined by the denoising_start instead.
908910
discrete_timestep_cutoff = int(
909911
round(
910912
self.scheduler.config.num_train_timesteps
911913
- (denoising_start * self.scheduler.config.num_train_timesteps)
912914
)
913915
)
914916

915-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
917+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
916918
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
917919
# if the scheduler is a 2nd order scheduler we might have to do +1
918920
# because `num_inference_steps` might be even given that every timestep
@@ -923,11 +925,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
923925
num_inference_steps = num_inference_steps + 1
924926

925927
# because t_n+1 >= t_n, we slice the timesteps starting from the end
926-
timesteps = timesteps[-num_inference_steps:]
928+
t_start = len(self.scheduler.timesteps) - num_inference_steps
929+
timesteps = self.scheduler.timesteps[t_start:]
930+
if hasattr(self.scheduler, "set_begin_index"):
931+
self.scheduler.set_begin_index(t_start)
927932
return timesteps, num_inference_steps
928933

929-
return timesteps, num_inference_steps - t_start
930-
931934
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
932935
def _get_add_time_ids(
933936
self,

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -640,22 +640,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
640640
if denoising_start is None:
641641
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
642642
t_start = max(num_inference_steps - init_timestep, 0)
643-
else:
644-
t_start = 0
645643

646-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
644+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
645+
if hasattr(self.scheduler, "set_begin_index"):
646+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
647+
648+
return timesteps, num_inference_steps - t_start
647649

648-
# Strength is irrelevant if we directly request a timestep to start at;
649-
# that is, strength is determined by the denoising_start instead.
650-
if denoising_start is not None:
650+
else:
651+
# Strength is irrelevant if we directly request a timestep to start at;
652+
# that is, strength is determined by the denoising_start instead.
651653
discrete_timestep_cutoff = int(
652654
round(
653655
self.scheduler.config.num_train_timesteps
654656
- (denoising_start * self.scheduler.config.num_train_timesteps)
655657
)
656658
)
657659

658-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
660+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
659661
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
660662
# if the scheduler is a 2nd order scheduler we might have to do +1
661663
# because `num_inference_steps` might be even given that every timestep
@@ -666,11 +668,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
666668
num_inference_steps = num_inference_steps + 1
667669

668670
# because t_n+1 >= t_n, we slice the timesteps starting from the end
669-
timesteps = timesteps[-num_inference_steps:]
671+
t_start = len(self.scheduler.timesteps) - num_inference_steps
672+
timesteps = self.scheduler.timesteps[t_start:]
673+
if hasattr(self.scheduler, "set_begin_index"):
674+
self.scheduler.set_begin_index(t_start)
670675
return timesteps, num_inference_steps
671676

672-
return timesteps, num_inference_steps - t_start
673-
674677
def prepare_latents(
675678
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
676679
):

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -901,22 +901,24 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
901901
if denoising_start is None:
902902
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
903903
t_start = max(num_inference_steps - init_timestep, 0)
904-
else:
905-
t_start = 0
906904

907-
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
905+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
906+
if hasattr(self.scheduler, "set_begin_index"):
907+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
908+
909+
return timesteps, num_inference_steps - t_start
908910

909-
# Strength is irrelevant if we directly request a timestep to start at;
910-
# that is, strength is determined by the denoising_start instead.
911-
if denoising_start is not None:
911+
else:
912+
# Strength is irrelevant if we directly request a timestep to start at;
913+
# that is, strength is determined by the denoising_start instead.
912914
discrete_timestep_cutoff = int(
913915
round(
914916
self.scheduler.config.num_train_timesteps
915917
- (denoising_start * self.scheduler.config.num_train_timesteps)
916918
)
917919
)
918920

919-
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
921+
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
920922
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
921923
# if the scheduler is a 2nd order scheduler we might have to do +1
922924
# because `num_inference_steps` might be even given that every timestep
@@ -927,11 +929,12 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
927929
num_inference_steps = num_inference_steps + 1
928930

929931
# because t_n+1 >= t_n, we slice the timesteps starting from the end
930-
timesteps = timesteps[-num_inference_steps:]
932+
t_start = len(self.scheduler.timesteps) - num_inference_steps
933+
timesteps = self.scheduler.timesteps[t_start:]
934+
if hasattr(self.scheduler, "set_begin_index"):
935+
self.scheduler.set_begin_index(t_start)
931936
return timesteps, num_inference_steps
932937

933-
return timesteps, num_inference_steps - t_start
934-
935938
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
936939
def _get_add_time_ids(
937940
self,

0 commit comments

Comments
 (0)