1313# limitations under the License.
1414
1515import math
16+ from dataclasses import dataclass
1617from typing import List , Optional , Tuple , Union
1718
1819import numpy as np
1920import torch
2021import torchsde
2122
2223from ..configuration_utils import ConfigMixin , register_to_config
23- from ..utils import is_scipy_available
24- from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SchedulerOutput
24+ from ..utils import BaseOutput , is_scipy_available
25+ from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin
2526
2627
2728if is_scipy_available ():
2829 import scipy .stats
2930
3031
32+ @dataclass
33+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
34+ class DPMSolverSDESchedulerOutput (BaseOutput ):
35+ """
36+ Output class for the scheduler's `step` function output.
37+
38+ Args:
39+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
40+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41+ denoising loop.
42+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
43+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44+ `pred_original_sample` can be used to preview progress or for guidance.
45+ """
46+
47+ prev_sample : torch .Tensor
48+ pred_original_sample : Optional [torch .Tensor ] = None
49+
50+
3151class BatchedBrownianTree :
3252 """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
3353
@@ -510,7 +530,7 @@ def step(
510530 sample : Union [torch .Tensor , np .ndarray ],
511531 return_dict : bool = True ,
512532 s_noise : float = 1.0 ,
513- ) -> Union [SchedulerOutput , Tuple ]:
533+ ) -> Union [DPMSolverSDESchedulerOutput , Tuple ]:
514534 """
515535 Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
516536 process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ def step(
522542 The current discrete timestep in the diffusion chain.
523543 sample (`torch.Tensor` or `np.ndarray`):
524544 A current instance of a sample created by the diffusion process.
525- return_dict (`bool`, *optional*, defaults to `True`):
526- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
545+ return_dict (`bool`):
546+ Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
547+ tuple.
527548 s_noise (`float`, *optional*, defaults to 1.0):
528549 Scaling factor for noise added to the sample.
529550
530551 Returns:
531- [`~schedulers.scheduling_utils.SchedulerOutput `] or `tuple`:
532- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput `] is returned, otherwise a
533- tuple is returned where the first element is the sample tensor.
552+ [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput `] or `tuple`:
553+ If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput `] is
554+ returned, otherwise a tuple is returned where the first element is the sample tensor.
534555 """
535556 if self .step_index is None :
536557 self ._init_step_index (timestep )
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
610631 self ._step_index += 1
611632
612633 if not return_dict :
613- return (prev_sample ,)
634+ return (
635+ prev_sample ,
636+ pred_original_sample ,
637+ )
614638
615- return SchedulerOutput (prev_sample = prev_sample )
639+ return DPMSolverSDESchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_original_sample )
616640
617641 # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
618642 def add_noise (
0 commit comments