diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 610e8d2d765c..a7cc4209fec4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np @@ -20,14 +21,33 @@ import torchsde from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import is_scipy_available -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput, is_scipy_available +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin if is_scipy_available(): import scipy.stats +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE +class DPMSolverSDESchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" @@ -510,7 +530,7 @@ def step( sample: Union[torch.Tensor, np.ndarray], return_dict: bool = True, s_noise: float = 1.0, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DPMSolverSDESchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). @@ -522,15 +542,16 @@ def step( The current discrete timestep in the diffusion chain. sample (`torch.Tensor` or `np.ndarray`): A current instance of a sample created by the diffusion process. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or + tuple. s_noise (`float`, *optional*, defaults to 1.0): Scaling factor for noise added to the sample. Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.step_index is None: self._init_step_index(timestep) @@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor: self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) - return SchedulerOutput(prev_sample=prev_sample) + return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index cb995df4af59..63f38e86ab45 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -13,20 +13,40 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import is_scipy_available -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput, is_scipy_available +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin if is_scipy_available(): import scipy.stats +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete +class HeunDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -455,7 +475,7 @@ def step( timestep: Union[float, torch.Tensor], sample: Union[torch.Tensor, np.ndarray], return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[HeunDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). @@ -468,12 +488,13 @@ def step( sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or + tuple. Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.step_index is None: self._init_step_index(timestep) @@ -544,9 +565,12 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) - return SchedulerOutput(prev_sample=prev_sample) + return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index b1ec244e5a79..6dfc024e221c 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -13,21 +13,41 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import is_scipy_available +from ..utils import BaseOutput, is_scipy_available from ..utils.torch_utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin if is_scipy_available(): import scipy.stats +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete +class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -459,7 +479,7 @@ def step( sample: Union[torch.Tensor, np.ndarray], generator: Optional[torch.Generator] = None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). @@ -474,12 +494,14 @@ def step( generator (`torch.Generator`, *optional*): A random number generator. return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + Whether or not to return a + [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple. Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.step_index is None: self._init_step_index(timestep) @@ -548,9 +570,14 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) - return SchedulerOutput(prev_sample=prev_sample) + return KDPM2AncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 422fe40556f0..bf3b9f1437d2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -13,20 +13,40 @@ # limitations under the License. import math +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import is_scipy_available -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput, is_scipy_available +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin if is_scipy_available(): import scipy.stats +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete +class KDPM2DiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, @@ -443,7 +463,7 @@ def step( timestep: Union[float, torch.Tensor], sample: Union[torch.Tensor, np.ndarray], return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). @@ -456,12 +476,13 @@ def step( sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or + tuple. Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.step_index is None: self._init_step_index(timestep) @@ -523,9 +544,12 @@ def step( prev_sample = sample + derivative * dt if not return_dict: - return (prev_sample,) + return ( + prev_sample, + pred_original_sample, + ) - return SchedulerOutput(prev_sample=prev_sample) + return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise(