44 changes: 34 additions & 10 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -13,21 +13,41 @@
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchsde

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


if is_scipy_available():
import scipy.stats


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
class DPMSolverSDESchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None


class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""

@@ -510,7 +530,7 @@ def step(
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
s_noise: float = 1.0,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ def step(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor` or `np.ndarray`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
tuple.
s_noise (`float`, *optional*, defaults to 1.0):
Scaling factor for noise added to the sample.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
self._step_index += 1

if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)

return SchedulerOutput(prev_sample=prev_sample)
return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
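For context, the new output class behaves like the other `BaseOutput` subclasses in diffusers. The minimal sketch below (not part of the PR) constructs one by hand with placeholder zero tensors to show the field-access options; it assumes a diffusers build containing this change and `torchsde` installed, since `scheduling_dpmsolver_sde` imports it at module level.

```python
import torch

from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDESchedulerOutput

# Placeholder tensors, just to show the fields added in this PR.
out = DPMSolverSDESchedulerOutput(
    prev_sample=torch.zeros(1, 4, 64, 64),
    pred_original_sample=torch.zeros(1, 4, 64, 64),
)

# Like other BaseOutput subclasses, fields are reachable by attribute, by key,
# or as a tuple, so existing `step(...)[0]`-style call sites keep working.
print(out.prev_sample.shape)
print(out["pred_original_sample"].shape)
prev_sample, pred_original_sample = out.to_tuple()
```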
42 changes: 33 additions & 9 deletions src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -13,20 +13,40 @@
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


if is_scipy_available():
import scipy.stats


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
class HeunDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -455,7 +475,7 @@ def step(
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -468,12 +488,13 @@ def step(
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
tuple.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -544,9 +565,12 @@ def step(
self._step_index += 1

if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)

return SchedulerOutput(prev_sample=prev_sample)
return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
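As a hedged usage sketch (not taken from the PR), a denoising loop can now read the denoised estimate from the step output to preview progress. The random tensors below stand in for real latents and a real UNet prediction; only the `HeunDiscreteScheduler` API calls come from diffusers, and the `pred_original_sample` field assumes a build containing this change.

```python
import torch

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# Placeholder latents; a real pipeline would start from encoded or sampled latents.
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = torch.randn_like(model_input)  # stand-in for a UNet noise prediction
    out = scheduler.step(model_output, t, sample)
    sample = out.prev_sample
    # New in this PR: the x0 estimate is exposed alongside prev_sample and can be
    # decoded or inspected to preview denoising progress.
    preview = out.pred_original_sample
```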
45 changes: 36 additions & 9 deletions src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -13,21 +13,41 @@
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from ..utils import BaseOutput, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


if is_scipy_available():
import scipy.stats


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -459,7 +479,7 @@ def step(
sample: Union[torch.Tensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -474,12 +494,14 @@ def step(
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Whether or not to return a
[`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -548,9 +570,14 @@ def step(
self._step_index += 1

if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)

return SchedulerOutput(prev_sample=prev_sample)
return KDPM2AncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample
)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
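For the ancestral variant, a one-step sketch (again with placeholder tensors, and assuming a build with this change) combines the existing `generator` argument from the signature above with the new output class.

```python
import torch

from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 4, 64, 64, generator=generator) * scheduler.init_noise_sigma
model_output = torch.randn_like(sample)  # stand-in for a UNet prediction

out = scheduler.step(model_output, scheduler.timesteps[0], sample, generator=generator)
print(type(out).__name__)              # KDPM2AncestralDiscreteSchedulerOutput
print(out.pred_original_sample.shape)  # same shape as the input sample
```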
42 changes: 33 additions & 9 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -13,20 +13,40 @@
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


if is_scipy_available():
import scipy.stats


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
class KDPM2DiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
@@ -443,7 +463,7 @@ def step(
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -456,12 +476,13 @@ def step(
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or
tuple.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -523,9 +544,12 @@ def step(
prev_sample = sample + derivative * dt

if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)

return SchedulerOutput(prev_sample=prev_sample)
return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
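One behavioral note, shown as a small sketch with placeholder tensors (not from the PR itself): with `return_dict=False` the step now returns a two-element tuple instead of a one-element tuple, so positional access to index 0 keeps working while code that unpacked exactly one element needs the second name.

```python
import torch

from diffusers import KDPM2DiscreteScheduler

scheduler = KDPM2DiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
model_output = torch.randn_like(sample)  # stand-in for a UNet prediction

# Before this PR: (prev_sample,). After: (prev_sample, pred_original_sample).
prev_sample, pred_original_sample = scheduler.step(
    model_output, scheduler.timesteps[0], sample, return_dict=False
)
```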