@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
127
127
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
128
128
steps_offset (`int`, defaults to 0):
129
129
An offset added to the inference steps, as required by some model families.
130
+ final_sigmas_type (`str`, defaults to `"zero"`):
131
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
132
+ is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
130
133
"""
131
134
132
135
_compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -153,6 +156,7 @@ def __init__(
153
156
use_karras_sigmas : Optional [bool ] = False ,
154
157
timestep_spacing : str = "linspace" ,
155
158
steps_offset : int = 0 ,
159
+ final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
156
160
):
157
161
if trained_betas is not None :
158
162
self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -265,10 +269,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
265
269
sigmas = np .flip (sigmas ).copy ()
266
270
sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
267
271
timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ]).round ()
268
- sigmas = np .concatenate ([sigmas , sigmas [- 1 :]]).astype (np .float32 )
272
+ if self .config .final_sigmas_type == "sigma_min" :
273
+ sigma_last = sigmas [- 1 ]
274
+ elif self .config .final_sigmas_type == "zero" :
275
+ sigma_last = 0
276
+ else :
277
+ raise ValueError (
278
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
279
+ )
280
+ sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
269
281
else :
270
282
sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
271
- sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
283
+ if self .config .final_sigmas_type == "sigma_min" :
284
+ sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
285
+ elif self .config .final_sigmas_type == "zero" :
286
+ sigma_last = 0
287
+ else :
288
+ raise ValueError (
289
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
290
+ )
272
291
sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
273
292
274
293
self .sigmas = torch .from_numpy (sigmas )
0 commit comments