@@ -153,6 +153,8 @@ def __init__(
153153 flow_shift : Optional [float ] = 1.0 ,
154154 timestep_spacing : str = "linspace" ,
155155 steps_offset : int = 0 ,
156+ use_dynamic_shifting : bool = False ,
157+ time_shift_type : str = "exponential" ,
156158 ):
157159 if self .config .use_beta_sigmas and not is_scipy_available ():
158160 raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
@@ -232,7 +234,7 @@ def set_begin_index(self, begin_index: int = 0):
232234 """
233235 self ._begin_index = begin_index
234236
235- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
237+ def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None , mu : Optional [ float ] = None ):
236238 """
237239 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
238240
@@ -242,6 +244,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
242244 device (`str` or `torch.device`, *optional*):
243245 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
244246 """
247+ if mu is not None :
248+ assert self .config .use_dynamic_shifting and self .config .time_shift_type == 'exponential'
249+ self .config .flow_shift = np .exp (mu )
245250 # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
246251 if self .config .timestep_spacing == "linspace" :
247252 timesteps = (
0 commit comments