@@ -169,6 +169,8 @@ def __init__(
169169 final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
170170 lambda_min_clipped : float = - float ("inf" ),
171171 variance_type : Optional [str ] = None ,
172+ use_dynamic_shifting : bool = False ,
173+ time_shift_type : str = "exponential" ,
172174 ):
173175 if self .config .use_beta_sigmas and not is_scipy_available ():
174176 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
@@ -301,6 +303,7 @@ def set_timesteps(
301303 self ,
302304 num_inference_steps : int = None ,
303305 device : Union [str , torch .device ] = None ,
306+ mu : Optional [float ] = None ,
304307 timesteps : Optional [List [int ]] = None ,
305308 ):
306309 """
@@ -316,6 +319,9 @@ def set_timesteps(
316319 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
317320 passed, `num_inference_steps` must be `None`.
318321 """
322+ if mu is not None :
323+ assert self .config .use_dynamic_shifting and self .config .time_shift_type == 'exponential'
324+ self .config .flow_shift = np .exp (mu )
319325 if num_inference_steps is None and timesteps is None :
320326 raise ValueError ("Must pass exactly one of `num_inference_steps` or `timesteps`." )
321327 if num_inference_steps is not None and timesteps is not None :
0 commit comments