Skip to content

Commit 4d687a2

Browse files
StAlKeR7779hlky
authored andcommitted
DPM++ third order fixes (huggingface#9104)
* Fix wrong output on 3n-1 steps count * Add sde handling to 3 order * make * copies --------- Co-authored-by: hlky <[email protected]>
1 parent 25f515f commit 4d687a2

File tree

4 files changed

+45
-3
lines changed

4 files changed

+45
-3
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@
343343
"StableDiffusion3ControlNetPipeline",
344344
"StableDiffusion3Img2ImgPipeline",
345345
"StableDiffusion3InpaintPipeline",
346-
"StableDiffusion3PAGPipeline",
347346
"StableDiffusion3PAGImg2ImgPipeline",
347+
"StableDiffusion3PAGPipeline",
348348
"StableDiffusion3Pipeline",
349349
"StableDiffusionAdapterPipeline",
350350
"StableDiffusionAttendAndExcitePipeline",

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,7 @@ def multistep_dpm_solver_third_order_update(
903903
model_output_list: List[torch.Tensor],
904904
*args,
905905
sample: torch.Tensor = None,
906+
noise: Optional[torch.Tensor] = None,
906907
**kwargs,
907908
) -> torch.Tensor:
908909
"""
@@ -981,6 +982,15 @@ def multistep_dpm_solver_third_order_update(
981982
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
982983
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
983984
)
985+
elif self.config.algorithm_type == "sde-dpmsolver++":
986+
assert noise is not None
987+
x_t = (
988+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
989+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
990+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
991+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
992+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
993+
)
984994
return x_t
985995

986996
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -1087,7 +1097,7 @@ def step(
10871097
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
10881098
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
10891099
else:
1090-
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1100+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
10911101

10921102
if self.lower_order_nums < self.config.solver_order:
10931103
self.lower_order_nums += 1

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ def multistep_dpm_solver_third_order_update(
764764
model_output_list: List[torch.Tensor],
765765
*args,
766766
sample: torch.Tensor = None,
767+
noise: Optional[torch.Tensor] = None,
767768
**kwargs,
768769
) -> torch.Tensor:
769770
"""
@@ -842,6 +843,15 @@ def multistep_dpm_solver_third_order_update(
842843
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
843844
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
844845
)
846+
elif self.config.algorithm_type == "sde-dpmsolver++":
847+
assert noise is not None
848+
x_t = (
849+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
850+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
851+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
852+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
853+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
854+
)
845855
return x_t
846856

847857
def _init_step_index(self, timestep):

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
264264
orders = [1, 2] * (steps // 2)
265265
elif order == 1:
266266
orders = [1] * steps
267+
268+
if self.config.final_sigmas_type == "zero":
269+
orders[-1] = 1
270+
267271
return orders
268272

269273
@property
@@ -812,6 +816,7 @@ def singlestep_dpm_solver_third_order_update(
812816
model_output_list: List[torch.Tensor],
813817
*args,
814818
sample: torch.Tensor = None,
819+
noise: Optional[torch.Tensor] = None,
815820
**kwargs,
816821
) -> torch.Tensor:
817822
"""
@@ -909,6 +914,23 @@ def singlestep_dpm_solver_third_order_update(
909914
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
910915
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
911916
)
917+
elif self.config.algorithm_type == "sde-dpmsolver++":
918+
assert noise is not None
919+
if self.config.solver_type == "midpoint":
920+
x_t = (
921+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
922+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
923+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
924+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
925+
)
926+
elif self.config.solver_type == "heun":
927+
x_t = (
928+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
929+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
930+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
931+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
932+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
933+
)
912934
return x_t
913935

914936
def singlestep_dpm_solver_update(
@@ -970,7 +992,7 @@ def singlestep_dpm_solver_update(
970992
elif order == 2:
971993
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
972994
elif order == 3:
973-
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
995+
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
974996
else:
975997
raise ValueError(f"Order must be 1, 2, 3, got {order}")
976998

0 commit comments

Comments
 (0)