1313# limitations under the License. 
1414
1515import  math 
16+ from  dataclasses  import  dataclass 
1617from  typing  import  List , Optional , Tuple , Union 
1718
1819import  numpy  as  np 
1920import  torch 
2021import  torchsde 
2122
2223from  ..configuration_utils  import  ConfigMixin , register_to_config 
23- from  ..utils  import  is_scipy_available 
24- from  .scheduling_utils  import  KarrasDiffusionSchedulers , SchedulerMixin ,  SchedulerOutput 
24+ from  ..utils  import  BaseOutput ,  is_scipy_available 
25+ from  .scheduling_utils  import  KarrasDiffusionSchedulers , SchedulerMixin 
2526
2627
2728if  is_scipy_available ():
2829    import  scipy .stats 
2930
3031
32+ @dataclass  
33+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE 
34+ class  DPMSolverSDESchedulerOutput (BaseOutput ):
35+     """ 
36+     Output class for the scheduler's `step` function output. 
37+ 
38+     Args: 
39+         prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 
40+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 
41+             denoising loop. 
42+         pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 
43+             The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 
44+             `pred_original_sample` can be used to preview progress or for guidance. 
45+     """ 
46+ 
47+     prev_sample : torch .Tensor 
48+     pred_original_sample : Optional [torch .Tensor ] =  None 
49+ 
50+ 
3151class  BatchedBrownianTree :
3252    """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" 
3353
@@ -510,7 +530,7 @@ def step(
510530        sample : Union [torch .Tensor , np .ndarray ],
511531        return_dict : bool  =  True ,
512532        s_noise : float  =  1.0 ,
513-     ) ->  Union [SchedulerOutput , Tuple ]:
533+     ) ->  Union [DPMSolverSDESchedulerOutput , Tuple ]:
514534        """ 
515535        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 
516536        process from the learned model outputs (most often the predicted noise). 
@@ -522,15 +542,16 @@ def step(
522542                The current discrete timestep in the diffusion chain. 
523543            sample (`torch.Tensor` or `np.ndarray`): 
524544                A current instance of a sample created by the diffusion process. 
525-             return_dict (`bool`, *optional*, defaults to `True`): 
526-                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. 
545+             return_dict (`bool`): 
546+                 Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or 
547+                 tuple. 
527548            s_noise (`float`, *optional*, defaults to 1.0): 
528549                Scaling factor for noise added to the sample. 
529550
530551        Returns: 
531-             [`~schedulers.scheduling_utils.SchedulerOutput `] or `tuple`: 
532-                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput `] is returned, otherwise a  
533-                 tuple is returned where the first element is the sample tensor. 
552+             [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput `] or `tuple`: 
553+                 If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput `] is 
554+                 returned, otherwise a  tuple is returned where the first element is the sample tensor. 
534555        """ 
535556        if  self .step_index  is  None :
536557            self ._init_step_index (timestep )
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
610631        self ._step_index  +=  1 
611632
612633        if  not  return_dict :
613-             return  (prev_sample ,)
634+             return  (
635+                 prev_sample ,
636+                 pred_original_sample ,
637+             )
614638
615-         return  SchedulerOutput (prev_sample = prev_sample )
639+         return  DPMSolverSDESchedulerOutput (prev_sample = prev_sample ,  pred_original_sample = pred_original_sample )
616640
617641    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 
618642    def  add_noise (
0 commit comments