@@ -332,29 +332,47 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
332332 )
333333
334334 sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
335- log_sigmas = np .log (sigmas )
336- if self .config .final_sigmas_type == "sigma_min" :
337- sigma_last = sigmas [- 1 ]
338- elif self .config .final_sigmas_type == "zero" :
339- sigma_last = 0
340- else :
341- raise ValueError (
342- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
343- )
344335 if self .config .use_karras_sigmas :
336+ log_sigmas = np .log (sigmas )
345337 sigmas = np .flip (sigmas ).copy ()
346338 sigmas = self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
347339 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ]).round ()
340+ if self .config .final_sigmas_type == "sigma_min" :
341+ sigma_last = sigmas [- 1 ]
342+ elif self .config .final_sigmas_type == "zero" :
343+ sigma_last = 0
344+ else :
345+ raise ValueError (
346+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
347+ )
348348 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
349349 elif self .config .use_exponential_sigmas :
350+ log_sigmas = np .log (sigmas )
350351 sigmas = np .flip (sigmas ).copy ()
351352 sigmas = self ._convert_to_exponential (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
352353 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ])
354+ if self .config .final_sigmas_type == "sigma_min" :
355+ sigma_last = sigmas [- 1 ]
356+ elif self .config .final_sigmas_type == "zero" :
357+ sigma_last = 0
358+ else :
359+ raise ValueError (
360+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
361+ )
353362 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
354363 elif self .config .use_beta_sigmas :
364+ log_sigmas = np .log (sigmas )
355365 sigmas = np .flip (sigmas ).copy ()
356366 sigmas = self ._convert_to_beta (in_sigmas = sigmas , num_inference_steps = num_inference_steps )
357367 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas ])
368+ if self .config .final_sigmas_type == "sigma_min" :
369+ sigma_last = sigmas [- 1 ]
370+ elif self .config .final_sigmas_type == "zero" :
371+ sigma_last = 0
372+ else :
373+ raise ValueError (
374+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
375+ )
358376 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
359377 else :
360378 sigmas = np .interp (timesteps , np .arange (0 , len (sigmas )), sigmas )
0 commit comments