@@ -245,7 +245,7 @@ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
245245
246246    def  set_timesteps (
247247        self ,
248-         num_inference_steps : int  =  None ,
248+         num_inference_steps : Optional [ int ]  =  None ,
249249        device : Union [str , torch .device ] =  None ,
250250        sigmas : Optional [List [float ]] =  None ,
251251        mu : Optional [float ] =  None ,
@@ -255,7 +255,7 @@ def set_timesteps(
255255        Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
256256
257257        Args: 
258-             num_inference_steps (`int`): 
258+             num_inference_steps (`int`, *optional* ): 
259259                The number of diffusion steps used when generating samples with a pre-trained model. 
260260            device (`str` or `torch.device`, *optional*): 
261261                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
@@ -270,22 +270,40 @@ def set_timesteps(
270270                automatically. 
271271        """ 
272272        if  self .config .use_dynamic_shifting  and  mu  is  None :
273-             raise  ValueError (" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" )
273+             raise  ValueError ("`mu` must be passed when `use_dynamic_shifting` is set to be `True`" )
274+ 
275+         if  sigmas  is  not None  and  timesteps  is  not None :
276+             if  len (sigmas ) !=  len (timesteps ):
277+                 raise  ValueError ("`sigmas` and `timesteps` should have the same length" )
278+ 
279+         if  num_inference_steps  is  not None :
280+             if  (sigmas  is  not None  and  len (sigmas ) !=  num_inference_steps ) or  (
281+                 timesteps  is  not None  and  len (timesteps ) !=  num_inference_steps 
282+             ):
283+                 raise  ValueError (
284+                     "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" 
285+                 )
286+         else :
287+             num_inference_steps  =  len (sigmas ) if  sigmas  is  not None  else  len (timesteps )
274288
275289        self .num_inference_steps  =  num_inference_steps 
276290
277291        # 1. Prepare default sigmas 
278292        is_timesteps_provided  =  timesteps  is  not None 
293+ 
294+         if  is_timesteps_provided :
295+             timesteps  =  np .array (timesteps ).astype (np .float32 )
296+ 
279297        if  sigmas  is  None :
280298            if  timesteps  is  None :
281299                timesteps  =  np .linspace (
282300                    self ._sigma_to_t (self .sigma_max ), self ._sigma_to_t (self .sigma_min ), num_inference_steps 
283301                )
284-             else :
285-                 timesteps  =  np .array (timesteps ).astype (np .float32 )
286302            sigmas  =  timesteps  /  self .config .num_train_timesteps 
303+             num_inference_steps  =  len (sigmas )
287304        else :
288305            sigmas  =  np .array (sigmas ).astype (np .float32 )
306+             num_inference_steps  =  len (sigmas )
289307
290308        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of 
291309        #    "exponential" or "linear" type is applied 
0 commit comments