@@ -689,7 +689,7 @@ def __call__(
689689 height : Optional [int ] = None ,
690690 width : Optional [int ] = None ,
691691 num_inference_steps : int = 50 ,
692- timesteps : List [int ] = None ,
692+ sigmas : Optional [ List [float ] ] = None ,
693693 guidance_scale : float = 30.0 ,
694694 num_images_per_prompt : Optional [int ] = 1 ,
695695 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
@@ -735,10 +735,10 @@ def __call__(
735735 num_inference_steps (`int`, *optional*, defaults to 50):
736736 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
737737 expense of slower inference.
738- timesteps (`List[int ]`, *optional*):
739- Custom timesteps to use for the denoising process with schedulers which support a `timesteps ` argument
740- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
741- passed will be used. Must be in descending order .
738+ sigmas (`List[float ]`, *optional*):
739+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas ` argument in
740+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
741+ will be used.
742742 guidance_scale (`float`, *optional*, defaults to 7.0):
743743 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
744744 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -878,7 +878,7 @@ def __call__(
878878 masked_image_latents = torch .cat ((masked_image_latents , mask ), dim = - 1 )
879879
880880 # 6. Prepare timesteps
881- sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
881+ sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
882882 image_seq_len = latents .shape [1 ]
883883 mu = calculate_shift (
884884 image_seq_len ,
@@ -891,8 +891,7 @@ def __call__(
891891 self .scheduler ,
892892 num_inference_steps ,
893893 device ,
894- timesteps ,
895- sigmas ,
894+ sigmas = sigmas ,
896895 mu = mu ,
897896 )
898897 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
0 commit comments