Skip to content

Commit 836ddb4

Browse files
committed
fix tests
1 parent d49e3e9 commit 836ddb4

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ def set_timesteps(
363363
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
364364
)
365365

366-
if sigmas is None:
367-
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
368366
if self.config.use_karras_sigmas:
367+
if sigmas is None:
368+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
369369
log_sigmas = np.log(sigmas)
370370
sigmas = np.flip(sigmas).copy()
371371
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -380,6 +380,8 @@ def set_timesteps(
380380
)
381381
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
382382
elif self.config.use_exponential_sigmas:
383+
if sigmas is None:
384+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
383385
log_sigmas = np.log(sigmas)
384386
sigmas = np.flip(sigmas).copy()
385387
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -394,6 +396,8 @@ def set_timesteps(
394396
)
395397
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
396398
elif self.config.use_beta_sigmas:
399+
if sigmas is None:
400+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
397401
log_sigmas = np.log(sigmas)
398402
sigmas = np.flip(sigmas).copy()
399403
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -433,6 +437,8 @@ def set_timesteps(
433437
)
434438
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
435439
else:
440+
if sigmas is None:
441+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
436442
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
437443
if self.config.final_sigmas_type == "sigma_min":
438444
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5

0 commit comments

Comments
 (0)