1515# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm 
1616
1717import  math 
18+ from  dataclasses  import  dataclass 
1819from  typing  import  List , Optional , Tuple , Union 
1920
2021import  numpy  as  np 
2122import  torch 
2223
2324from  ..configuration_utils  import  ConfigMixin , register_to_config 
25+ from  ..utils  import  BaseOutput 
2426from  ..utils .torch_utils  import  randn_tensor 
25- from  .scheduling_utils  import  SchedulerMixin , SchedulerOutput 
27+ from  .scheduling_utils  import  SchedulerMixin 
28+ 
29+ 
30+ @dataclass  
31+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMDPMSolverMultistep 
32+ class  EDMDPMSolverMultistepSchedulerOutput (BaseOutput ):
33+     """ 
34+     Output class for the scheduler's `step` function output. 
35+ 
36+     Args: 
37+         prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 
38+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 
39+             denoising loop. 
40+         pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): 
41+             The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 
42+             `pred_original_sample` can be used to preview progress or for guidance. 
43+     """ 
44+ 
45+     prev_sample : torch .Tensor 
46+     pred_original_sample : Optional [torch .Tensor ] =  None 
2647
2748
2849class  EDMDPMSolverMultistepScheduler (SchedulerMixin , ConfigMixin ):
@@ -593,7 +614,8 @@ def step(
593614        sample : torch .Tensor ,
594615        generator = None ,
595616        return_dict : bool  =  True ,
596-     ) ->  Union [SchedulerOutput , Tuple ]:
617+         pred_original_sample : Optional [torch .Tensor ] =  None ,
618+     ) ->  Union [EDMDPMSolverMultistepSchedulerOutput , Tuple ]:
597619        """ 
598620        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with 
599621        the multistep DPMSolver. 
@@ -608,12 +630,14 @@ def step(
608630            generator (`torch.Generator`, *optional*): 
609631                A random number generator. 
610632            return_dict (`bool`): 
611-                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
633+                 Whether or not to return a 
634+                 [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or a `tuple`. 
612635
613636        Returns: 
614-             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: 
615-                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a 
616-                 tuple is returned where the first element is the sample tensor. 
637+             [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] or `tuple`: 
638+                 If return_dict is `True`, 
639+                 [`~schedulers.scheduling_edm_dpmsolver_multistep.EDMDPMSolverMultistepSchedulerOutput`] is returned, 
640+                 otherwise a tuple is returned where the first element is the sample tensor. 
617641
618642        """ 
619643        if  self .num_inference_steps  is  None :
@@ -634,7 +658,12 @@ def step(
634658            (self .step_index  ==  len (self .timesteps ) -  2 ) and  self .config .lower_order_final  and  len (self .timesteps ) <  15 
635659        )
636660
637-         model_output  =  self .convert_model_output (model_output , sample = sample )
661+         if  pred_original_sample  is  None :
662+             model_output  =  self .convert_model_output (model_output , sample = sample )
663+         else :
664+             model_output  =  pred_original_sample 
665+             # TODO: thresholding is not handled in this case, but probably not needed either for Cosmos 
666+ 
638667        for  i  in  range (self .config .solver_order  -  1 ):
639668            self .model_outputs [i ] =  self .model_outputs [i  +  1 ]
640669        self .model_outputs [- 1 ] =  model_output 
@@ -662,7 +691,7 @@ def step(
662691        if  not  return_dict :
663692            return  (prev_sample ,)
664693
665-         return  SchedulerOutput (prev_sample = prev_sample )
694+         return  EDMDPMSolverMultistepSchedulerOutput (prev_sample = prev_sample ,  pred_original_sample = model_output )
666695
667696    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 
668697    def  add_noise (
0 commit comments