Skip to content

Commit be096b3

Browse files
committed
update
1 parent 03c3f69 commit be096b3

File tree

1 file changed

+79
-9
lines changed

1 file changed

+79
-9
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
rescale_betas_zero_snr: bool = False,
217217
use_dynamic_shifting: bool = False,
218218
time_shift_type: str = "exponential",
219+
shift_terminal: Optional[float] = None,
219220
):
220221
if self.config.use_beta_sigmas and not is_scipy_available():
221222
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -235,6 +236,8 @@ def __init__(
235236
self.betas = betas_for_alpha_bar(num_train_timesteps)
236237
else:
237238
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
239+
if shift_terminal is not None and not use_flow_sigmas:
240+
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
238241

239242
if rescale_betas_zero_snr:
240243
self.betas = rescale_zero_terminal_snr(self.betas)
@@ -303,7 +306,12 @@ def set_begin_index(self, begin_index: int = 0):
303306
self._begin_index = begin_index
304307

305308
def set_timesteps(
306-
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
309+
self,
310+
num_inference_steps: Optional[int] = None,
311+
device: Union[str, torch.device] = None,
312+
mu: Optional[float] = None,
313+
sigmas: Optional[List[float]] = None,
314+
timesteps: Optional[List[float]] = None,
307315
):
308316
"""
309317
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -314,10 +322,23 @@ def set_timesteps(
314322
device (`str` or `torch.device`, *optional*):
315323
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
316324
"""
325+
if self.config.use_dynamic_shifting and mu is None:
326+
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
327+
328+
if sigmas is not None or timesteps is not None:
329+
if not self.config.use_flow_sigmas:
330+
raise ValueError(
331+
"Passing `sigmas` or `timesteps` is only supported when `use_flow_sigmas=True`. "
332+
"Please set `use_flow_sigmas=True` during scheduler initialization."
333+
)
334+
num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
335+
if sigmas is not None and timesteps is not None:
336+
if len(sigmas) != len(timesteps):
337+
raise ValueError("`sigmas` and `timesteps` should have the same length")
338+
339+
is_timesteps_provided = timesteps is not None
340+
317341
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
318-
if mu is not None:
319-
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
320-
self.config.flow_shift = np.exp(mu)
321342
if self.config.timestep_spacing == "linspace":
322343
timesteps = (
323344
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -342,7 +363,8 @@ def set_timesteps(
342363
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
343364
)
344365

345-
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
366+
if sigmas is None:
367+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
346368
if self.config.use_karras_sigmas:
347369
log_sigmas = np.log(sigmas)
348370
sigmas = np.flip(sigmas).copy()
@@ -386,10 +408,21 @@ def set_timesteps(
386408
)
387409
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
388410
elif self.config.use_flow_sigmas:
389-
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
390-
sigmas = 1.0 - alphas
391-
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
392-
timesteps = (sigmas * self.config.num_train_timesteps).copy()
411+
if sigmas is None:
412+
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
413+
if self.config.use_dynamic_shifting:
414+
sigmas = self.time_shift(mu, 1.0, sigmas)
415+
else:
416+
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
417+
if self.config.shift_terminal:
418+
sigmas = self.stretch_shift_to_terminal(sigmas)
419+
eps = 1e-6
420+
if np.fabs(sigmas[0] - 1) < eps:
421+
sigmas[0] -= (
422+
eps # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
423+
)
424+
if not is_timesteps_provided:
425+
timesteps = (sigmas * self.config.num_train_timesteps).copy()
393426
if self.config.final_sigmas_type == "sigma_min":
394427
sigma_last = sigmas[-1]
395428
elif self.config.final_sigmas_type == "zero":
@@ -429,6 +462,43 @@ def set_timesteps(
429462
self._begin_index = None
430463
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
431464

465+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
466+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
467+
if self.config.time_shift_type == "exponential":
468+
return self._time_shift_exponential(mu, sigma, t)
469+
elif self.config.time_shift_type == "linear":
470+
return self._time_shift_linear(mu, sigma, t)
471+
472+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
473+
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
474+
r"""
475+
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
476+
value.
477+
478+
Reference:
479+
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
480+
481+
Args:
482+
t (`torch.Tensor`):
483+
A tensor of timesteps to be stretched and shifted.
484+
485+
Returns:
486+
`torch.Tensor`:
487+
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
488+
"""
489+
one_minus_z = 1 - t
490+
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
491+
stretched_t = 1 - (one_minus_z / scale_factor)
492+
return stretched_t
493+
494+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
495+
def _time_shift_exponential(self, mu, sigma, t):
496+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
497+
498+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
499+
def _time_shift_linear(self, mu, sigma, t):
500+
return mu / (mu + (1 / t - 1) ** sigma)
501+
432502
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
433503
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
434504
"""

0 commit comments

Comments
 (0)