@@ -733,6 +733,7 @@ def singlestep_dpm_solver_third_order_update(
733733        model_output_list : List [torch .Tensor ],
734734        * args ,
735735        sample : torch .Tensor  =  None ,
736+         noise : Optional [torch .Tensor ] =  None ,
736737        ** kwargs ,
737738    ) ->  torch .Tensor :
738739        """ 
@@ -830,6 +831,23 @@ def singlestep_dpm_solver_third_order_update(
830831                    -  (sigma_t  *  ((torch .exp (h ) -  1.0 ) /  h  -  1.0 )) *  D1 
831832                    -  (sigma_t  *  ((torch .exp (h ) -  1.0  -  h ) /  h ** 2  -  0.5 )) *  D2 
832833                )
834+         elif  self .config .algorithm_type  ==  "sde-dpmsolver++" :
835+             assert  noise  is  not None 
836+             if  self .config .solver_type  ==  "midpoint" :
837+                 x_t  =  (
838+                     (sigma_t  /  sigma_s2  *  torch .exp (- h )) *  sample 
839+                     +  (alpha_t  *  (1.0  -  torch .exp (- 2.0  *  h ))) *  D0 
840+                     +  (alpha_t  *  ((1.0  -  torch .exp (- 2.0  *  h )) /  (- 2.0  *  h ) +  1.0 )) *  D1_1 
841+                     +  sigma_t  *  torch .sqrt (1.0  -  torch .exp (- 2  *  h )) *  noise 
842+                 )
843+             elif  self .config .solver_type  ==  "heun" :
844+                 x_t  =  (
845+                     (sigma_t  /  sigma_s2  *  torch .exp (- h )) *  sample 
846+                     +  (alpha_t  *  (1.0  -  torch .exp (- 2.0  *  h ))) *  D0 
847+                     +  (alpha_t  *  ((1.0  -  torch .exp (- 2.0  *  h )) /  (- 2.0  *  h ) +  1.0 )) *  D1 
848+                     +  (alpha_t  *  ((1.0  -  torch .exp (- 2.0  *  h ) +  (- 2.0  *  h )) /  (- 2.0  *  h )** 2  -  0.5 )) *  D2 
849+                     +  sigma_t  *  torch .sqrt (1.0  -  torch .exp (- 2  *  h )) *  noise 
850+                 )
833851        return  x_t 
834852
835853    def  singlestep_dpm_solver_update (
@@ -891,7 +909,7 @@ def singlestep_dpm_solver_update(
891909        elif  order  ==  2 :
892910            return  self .singlestep_dpm_solver_second_order_update (model_output_list , sample = sample , noise = noise )
893911        elif  order  ==  3 :
894-             return  self .singlestep_dpm_solver_third_order_update (model_output_list , sample = sample )
912+             return  self .singlestep_dpm_solver_third_order_update (model_output_list , sample = sample ,  noise = noise )
895913        else :
896914            raise  ValueError (f"Order must be 1, 2, 3, got { order }  )
897915
0 commit comments