
Commit f0c8156

Add final_sigma_zero to UniPCMultistep (#7517)
* Add `final_sigma_zero` to UniPCMultistep. Effectively the same trick as DDIM's `set_alpha_to_one` and DPM's `final_sigma_type='zero'`. Currently False by default, but maybe this should be True?
* `final_sigma_zero: bool` -> `final_sigmas_type: str`. Should now match DPM Multistep 1:1.
* Set `final_sigmas_type='sigma_min'` in the UniPC unit tests.
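For context, a minimal sketch of how this option would typically be enabled on an existing pipeline. The checkpoint id is illustrative; rebuilding a scheduler via `from_config` with keyword overrides is the usual diffusers pattern:

```python
from diffusers import DiffusionPipeline, UniPCMultistepScheduler

# Illustrative checkpoint; any pipeline with a UniPC-compatible scheduler works.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Rebuild the scheduler from the existing config, overriding the new option.
# "zero" snaps the final sigma to 0; "sigma_min" keeps the old behavior of
# reusing the smallest sigma from the training schedule.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, final_sigmas_type="zero"
)
```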
1 parent 9d20ed3 commit f0c8156

2 files changed: +22 -2 lines

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 21 additions & 2 deletions
@@ -127,6 +127,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
+            is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
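As a quick illustration of the two documented settings, a constructor-only sketch (all other arguments left at their defaults):

```python
from diffusers import UniPCMultistepScheduler

# New default: the sampling schedule ends exactly at sigma = 0,
# analogous to DDIM's `set_alpha_to_one` and DPM's final_sigmas_type="zero".
scheduler_zero = UniPCMultistepScheduler(final_sigmas_type="zero")

# Previous behavior: the schedule ends at the smallest training-schedule sigma.
scheduler_min = UniPCMultistepScheduler(final_sigmas_type="sigma_min")
```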
@@ -153,6 +156,7 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -265,10 +269,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

         self.sigmas = torch.from_numpy(sigmas)
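To see what the new branches change in isolation, here is a small standalone numpy sketch; the helper name and the toy sigma values are invented for illustration:

```python
import numpy as np

def append_final_sigma(sigmas: np.ndarray, final_sigmas_type: str) -> np.ndarray:
    # Mirrors the appended-sigma logic in the hunk above, outside the scheduler.
    if final_sigmas_type == "sigma_min":
        sigma_last = sigmas[-1]  # reuse the smallest sigma in the schedule
    elif final_sigmas_type == "zero":
        sigma_last = 0  # end the sampling trajectory at exactly 0
    else:
        raise ValueError(f"unexpected final_sigmas_type: {final_sigmas_type}")
    return np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

toy_sigmas = np.array([14.6, 4.1, 1.2, 0.3, 0.03])  # toy descending schedule
print(append_final_sigma(toy_sigmas, "sigma_min"))  # [... 0.03 0.03]
print(append_final_sigma(toy_sigmas, "zero"))       # [... 0.03 0.  ]
```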

tests/schedulers/test_scheduler_unipc.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ def get_scheduler_config(self, **kwargs):
             "beta_schedule": "linear",
             "solver_order": 2,
             "solver_type": "bh2",
+            "final_sigmas_type": "sigma_min",
         }

         config.update(**kwargs)
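Because `get_scheduler_config` merges `kwargs` last via `config.update(**kwargs)`, individual tests can still exercise the `"zero"` path. A hypothetical test sketch (the test name and assertion are not part of this commit):

```python
def test_final_sigmas_type_zero(self):
    # Hypothetical: kwargs override the "sigma_min" default set above.
    scheduler_config = self.get_scheduler_config(final_sigmas_type="zero")
    scheduler = self.scheduler_classes[0](**scheduler_config)
    scheduler.set_timesteps(num_inference_steps=10)
    assert scheduler.sigmas[-1] == 0  # final sigma snapped to zero
```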
