diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index d25947d8d331..0617cc44d75a 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -14,7 +14,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = []
@@ -92,6 +95,7 @@ def __init__(
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -99,15 +103,24 @@ def __init__(
         # setable values
         self.num_inference_steps = None
 
-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
 
         self.is_scale_input_called = False
 
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps
 
-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, float):
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)
+        else:
+            sigmas = sigmas
         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication