@@ -394,6 +394,24 @@ def encode_prompt(
394394
395395        return  prompt_embeds , prompt_attention_mask 
396396
397+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 
398+     def  prepare_extra_step_kwargs (self , generator , eta ):
399+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 
400+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
401+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 
402+         # and should be between [0, 1] 
403+ 
404+         accepts_eta  =  "eta"  in  set (inspect .signature (self .scheduler .step ).parameters .keys ())
405+         extra_step_kwargs  =  {}
406+         if  accepts_eta :
407+             extra_step_kwargs ["eta" ] =  eta 
408+ 
409+         # check if the scheduler accepts generator 
410+         accepts_generator  =  "generator"  in  set (inspect .signature (self .scheduler .step ).parameters .keys ())
411+         if  accepts_generator :
412+             extra_step_kwargs ["generator" ] =  generator 
413+         return  extra_step_kwargs 
414+ 
397415    def  check_inputs (
398416        self ,
399417        prompt ,
@@ -835,10 +853,11 @@ def __call__(
835853        guidance  =  guidance .expand (latents .shape [0 ]).to (prompt_embeds .dtype )
836854        guidance  =  guidance  *  self .transformer .config .guidance_embeds_scale 
837855
838-         # YiYi  TODO: refactor this  
839-         timesteps  =  timesteps [: - 1 ] 
856+         # 6. Prepare extra step kwargs.  TODO: Logic should ideally just be moved out of the pipeline  
857+         extra_step_kwargs  =  self . prepare_extra_step_kwargs ( generator ,  eta ) 
840858
841859        # 7. Denoising loop 
860+         timesteps  =  timesteps [:- 1 ]
842861        num_warmup_steps  =  max (len (timesteps ) -  num_inference_steps  *  self .scheduler .order , 0 )
843862        self ._num_timesteps  =  len (timesteps )
844863
0 commit comments