diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 162a34bd2774..f3e871fe7289 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -216,6 +216,7 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         use_dynamic_shifting: bool = False,
         time_shift_type: str = "exponential",
+        shift_terminal: Optional[float] = None,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -235,6 +236,8 @@ def __init__(
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+        if shift_terminal is not None and not use_flow_sigmas:
+            raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
 
         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)
@@ -303,7 +306,12 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     def set_timesteps(
-        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
+        sigmas: Optional[List[float]] = None,
+        timesteps: Optional[List[float]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -314,10 +322,23 @@
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+        if sigmas is not None or timesteps is not None:
+            if not self.config.use_flow_sigmas:
+                raise ValueError(
+                    "Passing `sigmas` or `timesteps` is only supported when `use_flow_sigmas=True`. "
+                    "Please set `use_flow_sigmas=True` during scheduler initialization."
+                )
+            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
+            if sigmas is not None and timesteps is not None:
+                if len(sigmas) != len(timesteps):
+                    raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+        is_timesteps_provided = timesteps is not None
+
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
-        if mu is not None:
-            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
-            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -342,7 +363,8 @@
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if sigmas is None:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
@@ -386,10 +408,21 @@
             )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_flow_sigmas:
-            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
-            sigmas = 1.0 - alphas
-            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
-            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if sigmas is None:
+                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
+            if self.config.use_dynamic_shifting:
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
+            if self.config.shift_terminal:
+                sigmas = self.stretch_shift_to_terminal(sigmas)
+            eps = 1e-6
+            if np.fabs(sigmas[0] - 1) < eps:
+                sigmas[0] -= (
+                    eps  # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
+                )
+            if not is_timesteps_provided:
+                timesteps = (sigmas * self.config.num_train_timesteps).copy()
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = sigmas[-1]
             elif self.config.final_sigmas_type == "zero":
@@ -429,6 +462,43 @@
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        if self.config.time_shift_type == "exponential":
+            return self._time_shift_exponential(mu, sigma, t)
+        elif self.config.time_shift_type == "linear":
+            return self._time_shift_linear(mu, sigma, t)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
+    def _time_shift_exponential(self, mu, sigma, t):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
+    def _time_shift_linear(self, mu, sigma, t):
+        return mu / (mu + (1 / t - 1) ** sigma)
+
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
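
Usage note (not part of the diff): a minimal sketch of how the new `set_timesteps` code paths might be exercised once this patch is applied. The concrete values below (`flow_shift=3.0`, `shift_terminal=0.1`, 30 steps) are illustrative assumptions, not values taken from the patch.

import numpy as np
from diffusers import UniPCMultistepScheduler

# `shift_terminal` is only valid together with `use_flow_sigmas=True`;
# the new __init__ check above raises a ValueError otherwise.
scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=3.0,      # illustrative value
    shift_terminal=0.1,  # illustrative value
)

# Path 1: derive the flow sigmas from `num_inference_steps`, as before.
scheduler.set_timesteps(num_inference_steps=30)

# Path 2 (new): pass a custom sigma schedule directly; `num_inference_steps`
# is then inferred as `len(sigmas)`. A NumPy array is used here so the
# flow-shift arithmetic inside the flow-sigmas branch broadcasts elementwise.
custom_sigmas = np.linspace(1.0, 1.0 / 1000, 30)
scheduler.set_timesteps(sigmas=custom_sigmas)
print(len(scheduler.timesteps))  # 30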
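
For intuition, `stretch_shift_to_terminal` rescales each sigma's distance from 1 so the last entry lands exactly on `shift_terminal` while an initial sigma of 1 is preserved. A standalone arithmetic check of that identity (plain NumPy, independent of the scheduler; the schedule and terminal value are assumed for illustration):

import numpy as np

shift_terminal = 0.1  # assumed terminal value
sigmas = np.linspace(1.0, 1.0 / 1000, 30)  # naive flow schedule ending near 0

# Same arithmetic as stretch_shift_to_terminal in the diff above.
one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
stretched = 1 - one_minus_z / scale_factor

print(stretched[0])   # 1.0: the start of the schedule is unchanged
print(stretched[-1])  # ~0.1: the final sigma is pinned to shift_terminal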