Skip to content

Commit e8c9f9d

Browse files
committed
timesteps, set_timesteps(sigmas=..)
1 parent 98a52db commit e8c9f9d

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -403,10 +403,35 @@ def set_timesteps(
403403
num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
404404
self.num_inference_steps = num_inference_steps
405405

406+
if self.config.final_sigmas_type == "sigma_min" and not self.config.use_flow_match:
407+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
408+
elif self.config.final_sigmas_type == "zero":
409+
sigma_last = 0
410+
elif self.config.invert_sigmas:
411+
sigma_last = 1
412+
else:
413+
raise ValueError(
414+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
415+
)
416+
406417
if sigmas is not None and not self.config.use_flow_match:
407418
log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
408419
sigmas = np.array(sigmas).astype(np.float32)
409420
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
421+
elif sigmas is not None and self.config.use_flow_match:
422+
sigmas = np.array(sigmas).astype(np.float32)
423+
timesteps = sigmas * self.config.num_train_timesteps
424+
425+
if self.config.use_flow_match:
426+
if self.config.use_dynamic_shifting:
427+
sigmas = self.time_shift(mu, 1.0, sigmas)
428+
else:
429+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
430+
431+
if self.config.invert_sigmas:
432+
sigmas = 1.0 - sigmas
433+
timesteps = sigmas * self.config.num_train_timesteps
434+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
410435
elif sigmas is None:
411436
if timesteps is not None:
412437
timesteps = np.array(timesteps).astype(np.float32)
@@ -482,18 +507,7 @@ def set_timesteps(
482507

483508
if self.config.invert_sigmas:
484509
sigmas = 1.0 - sigmas
485-
timesteps = sigmas * self.config.num_train_timesteps
486-
487-
if self.config.final_sigmas_type == "sigma_min":
488-
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
489-
elif self.config.final_sigmas_type == "zero":
490-
sigma_last = 0
491-
elif self.config.invert_sigmas:
492-
sigma_last = 1
493-
else:
494-
raise ValueError(
495-
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
496-
)
510+
timesteps = sigmas * self.config.num_train_timesteps
497511

498512
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
499513

@@ -502,8 +516,6 @@ def set_timesteps(
502516
# TODO: Support the full EDM scalings for all prediction types and timestep types
503517
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
504518
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
505-
elif self.config.use_flow_match:
506-
self.timesteps = sigmas[:-1] * self.config.num_train_timesteps
507519
else:
508520
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
509521

0 commit comments

Comments
 (0)