@@ -281,6 +281,16 @@ def do_classifier_free_guidance(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
+        s = torch.tensor([0.008])
+        clamp_range = [0, 1]
+        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
+        var = alphas_cumprod[t]
+        var = var.clamp(*clamp_range)
+        s, min_var = s.to(var.device), min_var.to(var.device)
+        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return ratio
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
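For context (not part of the patch itself): the new helper inverts the cosine noise schedule that DDPMWuerstchenScheduler uses to map a continuous ratio in [0, 1] to alpha_cumprod, so that a discrete scheduler's alphas_cumprod can be mapped back to the ratio the decoder is conditioned on. A minimal round-trip sketch, assuming the default s = 0.008 and no scaler warping:

import torch

s = torch.tensor([0.008])
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2

# Forward cosine schedule: continuous ratio in [0, 1] -> alpha_cumprod.
def alpha_cumprod_from_ratio(ratio):
    return torch.cos((ratio + s) / (1 + s) * torch.pi * 0.5) ** 2 / min_var

ratio = torch.tensor([0.3])
var = alpha_cumprod_from_ratio(ratio).clamp(0, 1)

# Inverse, as in get_timestep_ratio_conditioning: recovers the original ratio.
recovered = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
assert torch.allclose(recovered, ratio, atol=1e-5)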
@@ -434,10 +444,30 @@ def __call__(
             batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
+        if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+            timesteps = timesteps[:-1]
+        else:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
+                self.scheduler.config.clip_sample = False  # disable sample clipping
+                logger.warning("set `clip_sample` to be False")
+
         # 6. Run denoising loop
-        self._num_timesteps = len(timesteps[:-1])
-        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
-            timestep_ratio = t.expand(latents.size(0)).to(dtype)
+        if hasattr(self.scheduler, "betas"):
+            alphas = 1.0 - self.scheduler.betas
+            alphas_cumprod = torch.cumprod(alphas, dim=0)
+        else:
+            alphas_cumprod = []
+
+        self._num_timesteps = len(timesteps)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                if len(alphas_cumprod) > 0:
+                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
+                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
+                else:
+                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
+            else:
+                timestep_ratio = t.expand(latents.size(0)).to(dtype)
 
             # 7. Denoise latents
             predicted_latents = self.decoder(
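For schedulers other than DDPMWuerstchenScheduler, the loop now conditions the decoder on a ratio derived from the scheduler's own beta schedule when one is available, falling back to dividing by the largest timestep otherwise. A rough, self-contained sketch of the inputs that path expects, using DDPMScheduler purely as a hypothetical stand-in:

import torch
from diffusers import DDPMScheduler

# Stand-in for any discrete scheduler that exposes `betas`.
scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(20)

# The pipeline derives alphas_cumprod from the scheduler's betas ...
alphas_cumprod = torch.cumprod(1.0 - scheduler.betas, dim=0)

# ... and indexes it with each discrete timestep before inverting the cosine schedule.
t = scheduler.timesteps[0]
print(t.item(), alphas_cumprod[t.long()].item())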
@@ -454,6 +484,8 @@ def __call__(
                 predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
 
             # 9. Renoise latents to next timestep
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                timestep_ratio = t
             latents = self.scheduler.step(
                 model_output=predicted_latents,
                 timestep=timestep_ratio,
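One detail worth noting: for non-Würstchen schedulers the continuous ratio is only used to condition the decoder, while `scheduler.step` still receives the raw discrete timestep, which is presumably why `timestep_ratio` is reset to `t` just before the step call. A sketch of that step call with the same hypothetical stand-in scheduler (illustration only, random tensors in place of real latents):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(20)

t = scheduler.timesteps[0]
latents = torch.randn(1, 4, 16, 16)
predicted_latents = torch.randn_like(latents)

# Standard diffusers schedulers expect the discrete timestep here, not a ratio.
latents = scheduler.step(model_output=predicted_latents, timestep=t, sample=latents).prev_sample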