@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
+            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
             paper, and the `dpmsolver++` type implements the algorithms in the
             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -150,9 +153,14 @@ def __init__(
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
+        if algorithm_type == "dpmsolver":
+            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose `dpmsolver++` instead."
+            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
+
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
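A quick sketch of the deprecation path, assuming `deprecate` emits a `FutureWarning` as the diffusers helper does by default; `final_sigmas_type="sigma_min"` is passed so the stricter validation added further down does not reject the legacy algorithm:

```python
import warnings

from diffusers import DPMSolverSinglestepScheduler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "sigma_min" keeps the legacy algorithm compatible with the new
    # final-sigma check in __init__.
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="sigma_min")

assert any("deprecated" in str(w.message) for w in caught)
```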
@@ -189,6 +197,11 @@ def __init__(
             else:
                 raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

+        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+            raise ValueError(
+                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
+            )
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
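To illustrate the new guard (a sketch; the printed message follows the error text above):

```python
from diffusers import DPMSolverSinglestepScheduler

# A zero final sigma is only defined for dpmsolver++, so this
# combination is rejected at construction time.
try:
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="zero")
except ValueError as err:
    print(err)
```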
@@ -267,11 +280,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+
+        if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
-            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
+            )
+        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

         self.sigmas = torch.from_numpy(sigmas).to(device=device)

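The effect on the schedule can be checked directly (a sketch; 10 steps is arbitrary):

```python
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(final_sigmas_type="zero")
scheduler.set_timesteps(num_inference_steps=10)

# One extra entry is appended for the final step; with
# final_sigmas_type="zero" that entry is exactly 0.
assert scheduler.sigmas.shape[0] == 11
assert scheduler.sigmas[-1] == 0
```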
@@ -285,6 +305,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             )
             self.register_to_config(lower_order_final=True)

+        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
+            logger.warning(
+                f"`final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
+            )
+            self.register_to_config(lower_order_final=True)
+
         self.order_list = self.get_order_list(num_inference_steps)

         # add an index counter for schedulers that allow duplicated timesteps
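Finally, the config override can be observed as follows (a sketch; the warning text is emitted through diffusers' logger):

```python
from diffusers import DPMSolverSinglestepScheduler

# lower_order_final=False is incompatible with a zero final sigma, so
# set_timesteps logs a warning and flips the flag back to True.
scheduler = DPMSolverSinglestepScheduler(lower_order_final=False, final_sigmas_type="zero")
scheduler.set_timesteps(num_inference_steps=10)

assert scheduler.config.lower_order_final is True
```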