Skip to content

Commit 2d01740

Browse files
committed
support edm dpmsolver multistep
1 parent 06e852d commit 2d01740

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,35 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
1616

1717
import math
18+
from dataclasses import dataclass
1819
from typing import List, Optional, Tuple, Union
1920

2021
import numpy as np
2122
import torch
2223

2324
from ..configuration_utils import ConfigMixin, register_to_config
25+
from ..utils import BaseOutput
2426
from ..utils.torch_utils import randn_tensor
25-
from .scheduling_utils import SchedulerMixin, SchedulerOutput
27+
from .scheduling_utils import SchedulerMixin
28+
29+
30+
@dataclass
31+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMDPMSolverMultistep
32+
class EDMDPMSolverMultistepSchedulerOutput(BaseOutput):
33+
"""
34+
Output class for the scheduler's `step` function output.
35+
36+
Args:
37+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
38+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
39+
denoising loop.
40+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
41+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
42+
`pred_original_sample` can be used to preview progress or for guidance.
43+
"""
44+
45+
prev_sample: torch.Tensor
46+
pred_original_sample: Optional[torch.Tensor] = None
2647

2748

2849
class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
@@ -593,7 +614,8 @@ def step(
593614
sample: torch.Tensor,
594615
generator=None,
595616
return_dict: bool = True,
596-
) -> Union[SchedulerOutput, Tuple]:
617+
pred_original_sample: Optional[torch.Tensor] = None,
618+
) -> Union[EDMDPMSolverMultistepSchedulerOutput, Tuple]:
597619
"""
598620
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
599621
the multistep DPMSolver.
@@ -608,12 +630,14 @@ def step(
608630
generator (`torch.Generator`, *optional*):
609631
A random number generator.
610632
return_dict (`bool`):
611-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
633+
Whether or not to return a
634+
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or a `tuple`.
612635
613636
Returns:
614-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
615-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
616-
tuple is returned where the first element is the sample tensor.
637+
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or `tuple`:
638+
If return_dict is `True`,
639+
[`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] is returned,
640+
otherwise a tuple is returned where the first element is the sample tensor.
617641
618642
"""
619643
if self.num_inference_steps is None:
@@ -634,7 +658,12 @@ def step(
634658
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
635659
)
636660

637-
model_output = self.convert_model_output(model_output, sample=sample)
661+
if pred_original_sample is None:
662+
model_output = self.convert_model_output(model_output, sample=sample)
663+
else:
664+
model_output = pred_original_sample
665+
# TODO: thresholding is not handled in this case, but probably not needed either for Cosmos
666+
638667
for i in range(self.config.solver_order - 1):
639668
self.model_outputs[i] = self.model_outputs[i + 1]
640669
self.model_outputs[-1] = model_output
@@ -662,7 +691,7 @@ def step(
662691
if not return_dict:
663692
return (prev_sample,)
664693

665-
return SchedulerOutput(prev_sample=prev_sample)
694+
return EDMDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, pred_original_sample=model_output)
666695

667696
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
668697
def add_noise(

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
@dataclass
31-
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
31+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMEuler
3232
class EDMEulerSchedulerOutput(BaseOutput):
3333
"""
3434
Output class for the scheduler's `step` function output.

0 commit comments

Comments
 (0)