Skip to content

Commit e73c056

Browse files
committed
Add sde handling to 3 order
1 parent 2e8e6cc commit e73c056

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ def multistep_dpm_solver_third_order_update(
806806
model_output_list: List[torch.Tensor],
807807
*args,
808808
sample: torch.Tensor = None,
809+
noise: Optional[torch.Tensor] = None,
809810
**kwargs,
810811
) -> torch.Tensor:
811812
"""
@@ -884,6 +885,15 @@ def multistep_dpm_solver_third_order_update(
884885
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
885886
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
886887
)
888+
elif self.config.algorithm_type == "sde-dpmsolver++":
889+
assert noise is not None
890+
x_t = (
891+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
892+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
893+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
894+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h)**2 - 0.5)) * D2
895+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
896+
)
887897
return x_t
888898

889899
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -990,7 +1000,7 @@ def step(
9901000
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
9911001
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
9921002
else:
993-
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
1003+
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
9941004

9951005
if self.lower_order_nums < self.config.solver_order:
9961006
self.lower_order_nums += 1

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ def singlestep_dpm_solver_third_order_update(
733733
model_output_list: List[torch.Tensor],
734734
*args,
735735
sample: torch.Tensor = None,
736+
noise: Optional[torch.Tensor] = None,
736737
**kwargs,
737738
) -> torch.Tensor:
738739
"""
@@ -830,6 +831,23 @@ def singlestep_dpm_solver_third_order_update(
830831
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
831832
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
832833
)
834+
elif self.config.algorithm_type == "sde-dpmsolver++":
835+
assert noise is not None
836+
if self.config.solver_type == "midpoint":
837+
x_t = (
838+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
839+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
840+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
841+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
842+
)
843+
elif self.config.solver_type == "heun":
844+
x_t = (
845+
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
846+
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
847+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
848+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h)**2 - 0.5)) * D2
849+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
850+
)
833851
return x_t
834852

835853
def singlestep_dpm_solver_update(
@@ -891,7 +909,7 @@ def singlestep_dpm_solver_update(
891909
elif order == 2:
892910
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
893911
elif order == 3:
894-
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
912+
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
895913
else:
896914
raise ValueError(f"Order must be 1, 2, 3, got {order}")
897915

0 commit comments

Comments
 (0)