Skip to content

Commit af7efa0

Browse files
StAlKeR7779hlky
authored andcommitted
DPM++ third order fixes (#9104)
* Fix wrong output on 3n-1 steps count * Add sde handling to 3 order * make * copies --------- Co-authored-by: hlky <[email protected]>
1 parent a9f8f2e commit af7efa0

File tree

4 files changed

+45
-3
lines changed

4 files changed

+45
-3
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@
338338
"StableDiffusion3ControlNetPipeline",
339339
"StableDiffusion3Img2ImgPipeline",
340340
"StableDiffusion3InpaintPipeline",
341-
"StableDiffusion3PAGPipeline",
342341
"StableDiffusion3PAGImg2ImgPipeline",
342+
"StableDiffusion3PAGPipeline",
343343
"StableDiffusion3Pipeline",
344344
"StableDiffusionAdapterPipeline",
345345
"StableDiffusionAttendAndExcitePipeline",

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ def multistep_dpm_solver_third_order_update(
889889
model_output_list: List[torch.Tensor],
890890
*args,
891891
sample: torch.Tensor = None,
892+
noise: Optional[torch.Tensor] = None,
892893
**kwargs,
893894
) -> torch.Tensor:
894895
"""
@@ -967,6 +968,15 @@ def multistep_dpm_solver_third_order_update(
967968
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
968969
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
969970
)
971+
elif self.config.algorithm_type == "sde-dpmsolver++":
972+
assert noise is not None
973+
x_t = (
974+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
975+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
976+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
977+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
978+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
979+
)
970980
return x_t
971981

972982
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -1073,7 +1083,7 @@ def step(
10731083
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
10741084
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
10751085
else:
1076-
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1086+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
10771087

10781088
if self.lower_order_nums < self.config.solver_order:
10791089
self.lower_order_nums += 1

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ def multistep_dpm_solver_third_order_update(
764764
model_output_list: List[torch.Tensor],
765765
*args,
766766
sample: torch.Tensor = None,
767+
noise: Optional[torch.Tensor] = None,
767768
**kwargs,
768769
) -> torch.Tensor:
769770
"""
@@ -842,6 +843,15 @@ def multistep_dpm_solver_third_order_update(
842843
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
843844
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
844845
)
846+
elif self.config.algorithm_type == "sde-dpmsolver++":
847+
assert noise is not None
848+
x_t = (
849+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
850+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
851+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
852+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
853+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
854+
)
845855
return x_t
846856

847857
def _init_step_index(self, timestep):

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
264264
orders = [1, 2] * (steps // 2)
265265
elif order == 1:
266266
orders = [1] * steps
267+
268+
if self.config.final_sigmas_type == "zero":
269+
orders[-1] = 1
270+
267271
return orders
268272

269273
@property
@@ -812,6 +816,7 @@ def singlestep_dpm_solver_third_order_update(
812816
model_output_list: List[torch.Tensor],
813817
*args,
814818
sample: torch.Tensor = None,
819+
noise: Optional[torch.Tensor] = None,
815820
**kwargs,
816821
) -> torch.Tensor:
817822
"""
@@ -909,6 +914,23 @@ def singlestep_dpm_solver_third_order_update(
909914
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
910915
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
911916
)
917+
elif self.config.algorithm_type == "sde-dpmsolver++":
918+
assert noise is not None
919+
if self.config.solver_type == "midpoint":
920+
x_t = (
921+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
922+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
923+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
924+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
925+
)
926+
elif self.config.solver_type == "heun":
927+
x_t = (
928+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
929+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
930+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
931+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
932+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
933+
)
912934
return x_t
913935

914936
def singlestep_dpm_solver_update(
@@ -970,7 +992,7 @@ def singlestep_dpm_solver_update(
970992
elif order == 2:
971993
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
972994
elif order == 3:
973-
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
995+
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
974996
else:
975997
raise ValueError(f"Order must be 1, 2, 3, got {order}")
976998

0 commit comments

Comments
 (0)