Skip to content

Commit d304b1f

Browse files
committed
Fix unipc
1 parent 01c6f31 commit d304b1f

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,29 +332,47 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
332332
)
333333

334334
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
335-
log_sigmas = np.log(sigmas)
336-
if self.config.final_sigmas_type == "sigma_min":
337-
sigma_last = sigmas[-1]
338-
elif self.config.final_sigmas_type == "zero":
339-
sigma_last = 0
340-
else:
341-
raise ValueError(
342-
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
343-
)
344335
if self.config.use_karras_sigmas:
336+
log_sigmas = np.log(sigmas)
345337
sigmas = np.flip(sigmas).copy()
346338
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
347339
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
340+
if self.config.final_sigmas_type == "sigma_min":
341+
sigma_last = sigmas[-1]
342+
elif self.config.final_sigmas_type == "zero":
343+
sigma_last = 0
344+
else:
345+
raise ValueError(
346+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
347+
)
348348
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
349349
elif self.config.use_exponential_sigmas:
350+
log_sigmas = np.log(sigmas)
350351
sigmas = np.flip(sigmas).copy()
351352
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
352353
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
354+
if self.config.final_sigmas_type == "sigma_min":
355+
sigma_last = sigmas[-1]
356+
elif self.config.final_sigmas_type == "zero":
357+
sigma_last = 0
358+
else:
359+
raise ValueError(
360+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
361+
)
353362
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
354363
elif self.config.use_beta_sigmas:
364+
log_sigmas = np.log(sigmas)
355365
sigmas = np.flip(sigmas).copy()
356366
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
357367
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
368+
if self.config.final_sigmas_type == "sigma_min":
369+
sigma_last = sigmas[-1]
370+
elif self.config.final_sigmas_type == "zero":
371+
sigma_last = 0
372+
else:
373+
raise ValueError(
374+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
375+
)
358376
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
359377
else:
360378
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

0 commit comments

Comments
 (0)