@@ -689,7 +689,7 @@ def __call__(
689689        height : Optional [int ] =  None ,
690690        width : Optional [int ] =  None ,
691691        num_inference_steps : int  =  50 ,
692-         timesteps :  List [int ] =  None ,
692+         sigmas :  Optional [ List [float ] ] =  None ,
693693        guidance_scale : float  =  30.0 ,
694694        num_images_per_prompt : Optional [int ] =  1 ,
695695        generator : Optional [Union [torch .Generator , List [torch .Generator ]]] =  None ,
@@ -735,10 +735,10 @@ def __call__(
735735            num_inference_steps (`int`, *optional*, defaults to 50): 
736736                The number of denoising steps. More denoising steps usually lead to a higher quality image at the 
737737                expense of slower inference. 
738-             timesteps  (`List[int ]`, *optional*): 
739-                 Custom timesteps  to use for the denoising process with schedulers which support a `timesteps ` argument 
740-                 in  their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 
741-                 passed  will be used. Must be in descending order . 
738+             sigmas  (`List[float ]`, *optional*): 
739+                 Custom sigmas  to use for the denoising process with schedulers which support a `sigmas ` argument in  
740+                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed  
741+                 will be used. 
742742            guidance_scale (`float`, *optional*, defaults to 7.0): 
743743                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
744744                `guidance_scale` is defined as `w` of equation 2. of [Imagen 
@@ -878,7 +878,7 @@ def __call__(
878878            masked_image_latents  =  torch .cat ((masked_image_latents , mask ), dim = - 1 )
879879
880880        # 6. Prepare timesteps 
881-         sigmas  =  np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps )
881+         sigmas  =  np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps )  if   sigmas   is   None   else   sigmas 
882882        image_seq_len  =  latents .shape [1 ]
883883        mu  =  calculate_shift (
884884            image_seq_len ,
@@ -891,8 +891,7 @@ def __call__(
891891            self .scheduler ,
892892            num_inference_steps ,
893893            device ,
894-             timesteps ,
895-             sigmas ,
894+             sigmas = sigmas ,
896895            mu = mu ,
897896        )
898897        num_warmup_steps  =  max (len (timesteps ) -  num_inference_steps  *  self .scheduler .order , 0 )
0 commit comments