@@ -196,13 +196,21 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_match: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
+        shift: float = 1.0,
+        use_dynamic_shifting: bool = False,
+        base_shift: Optional[float] = 0.5,
+        max_shift: Optional[float] = 1.15,
+        base_image_seq_len: Optional[int] = 256,
+        max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
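
For orientation, here is a minimal sketch of how the new flow-match options added above might be passed at construction time, assuming this branch is installed; the specific values are illustrative, not recommendations:

```python
from diffusers import EulerDiscreteScheduler

# hypothetical configuration exercising the flags added in this hunk
scheduler = EulerDiscreteScheduler(
    num_train_timesteps=1000,
    use_flow_match=True,         # switch to the rectified-flow sigma schedule
    shift=3.0,                   # static shift, applied when use_dynamic_shifting=False
    use_dynamic_shifting=False,  # True defers shifting to set_timesteps(mu=...)
)
```
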
@@ -234,20 +242,39 @@ def __init__(
             # FP16 smallest positive subnormal works well here
             self.alphas_cumprod[-1] = 2**-24
 
-        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        if use_flow_match:
+            timestep_offset = 1
+        else:
+            timestep_offset = 0
+
+        timesteps = np.linspace(
+            0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
+        )[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
+        if use_flow_match:
+            sigmas = timesteps / num_train_timesteps
+            if not use_dynamic_shifting:
+                # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+                sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        else:
+            sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
+
         # setable values
         self.num_inference_steps = None
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if timestep_type == "continuous" and prediction_type == "v_prediction":
             self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+        elif use_flow_match:
+            self.timesteps = sigmas * num_train_timesteps
         else:
             self.timesteps = timesteps
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if not use_flow_match:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+        self.sigmas = sigmas
 
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
@@ -257,6 +284,8 @@ def __init__(
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
 
     @property
     def init_noise_sigma(self):
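
The flow-match branch above builds sigmas as linearly decreasing fractions of the training schedule, then warps them with the static shift. A numpy-only sketch of those two lines, useful for inspecting the effect of `shift` (the value chosen here is an assumption; the config default of 1.0 is a no-op):

```python
import numpy as np

num_train_timesteps = 1000
shift = 3.0  # assumed value for illustration

# flow-match sigmas: descending fractions in (0, 1], note the offset of 1
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps)[::-1]
sigmas = timesteps / num_train_timesteps

# static shift: for shift > 1, sigmas are pulled toward 1, spending more of
# the schedule at high noise; sigma = 1 is a fixed point of the mapping
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)
```
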
@@ -322,6 +351,7 @@ def set_timesteps(
         device: Union[str, torch.device] = None,
         timesteps: Optional[List[int]] = None,
         sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -362,57 +392,81 @@ def set_timesteps(
             raise ValueError(
                 "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
             )
+        if timesteps is not None and self.config.use_flow_match:
+            # TODO: `timesteps / self.config.num_train_timesteps` to get sigmas?
+            raise ValueError("Cannot set `timesteps` with `config.use_flow_match = True`.")
+
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`.")
 
         if num_inference_steps is None:
             num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
         self.num_inference_steps = num_inference_steps
 
-        if sigmas is not None:
+        if sigmas is not None and not self.config.use_flow_match:
             log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
             sigmas = np.array(sigmas).astype(np.float32)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
-
-        else:
+        elif sigmas is None:
             if timesteps is not None:
                 timesteps = np.array(timesteps).astype(np.float32)
             else:
-                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-                if self.config.timestep_spacing == "linspace":
+                if self.config.use_flow_match:
                     timesteps = np.linspace(
-                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
-                    )[::-1].copy()
-                elif self.config.timestep_spacing == "leading":
-                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-                    )
-                    timesteps += self.config.steps_offset
-                elif self.config.timestep_spacing == "trailing":
-                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                        self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
                     )
-                    timesteps -= 1
+                    sigmas = timesteps / self.config.num_train_timesteps
                 else:
-                    raise ValueError(
-                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-                    )
+                    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+                    if self.config.timestep_spacing == "linspace":
+                        timesteps = np.linspace(
+                            0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                        )[::-1].copy()
+                    elif self.config.timestep_spacing == "leading":
+                        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                        )
+                        timesteps += self.config.steps_offset
+                    elif self.config.timestep_spacing == "trailing":
+                        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
+                            .round()
+                            .copy()
+                            .astype(np.float32)
+                        )
+                        timesteps -= 1
+                    else:
+                        raise ValueError(
+                            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                        )
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            if self.config.interpolation_type == "linear":
+                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+            elif self.config.interpolation_type == "log_linear":
+                sigmas = (
+                    torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)
+                    .exp()
+                    .numpy()
+                )
+            else:
+                raise ValueError(
+                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                    " 'linear' or 'log_linear'"
+                )
+
+        if self.config.use_flow_match:
+            if self.config.use_dynamic_shifting:
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         log_sigmas = np.log(sigmas)
-        if self.config.interpolation_type == "linear":
-            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        elif self.config.interpolation_type == "log_linear":
-            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
-        else:
-            raise ValueError(
-                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
-                " 'linear' or 'log_linear'"
-            )
 
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
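
When `use_dynamic_shifting` is enabled, the shift above becomes resolution-dependent and arrives through `mu`. This diff does not include the helper that computes `mu`; the sketch below shows one plausible derivation, linearly interpolating between `(base_image_seq_len, base_shift)` and `(max_image_seq_len, max_shift)` the way Flux-style pipelines do. The function name and call site are hypothetical:

```python
def compute_mu(
    image_seq_len: int,
    base_seq_len: int = 256,    # mirrors base_image_seq_len
    max_seq_len: int = 4096,    # mirrors max_image_seq_len
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # straight line through (base_seq_len, base_shift) and (max_seq_len, max_shift)
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# with use_dynamic_shifting=True, set_timesteps requires mu:
# scheduler.set_timesteps(num_inference_steps=28, mu=compute_mu(1024))
```
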
@@ -426,10 +480,16 @@ def set_timesteps(
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
             sigma_last = 0
+        elif self.config.invert_sigmas:
+            sigma_last = 1
         else:
             raise ValueError(
                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
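
A quick sketch of what `invert_sigmas` does to a descending flow-match schedule: the sigmas are mirrored so they ascend from 0 toward 1, and the timesteps are rebuilt from the inverted values (the numbers below are illustrative):

```python
import numpy as np

sigmas = np.array([1.0, 0.75, 0.5, 0.25])   # descending flow-match schedule
inverted = 1.0 - sigmas                      # [0.0, 0.25, 0.5, 0.75], ascending
timesteps = inverted * 1000                  # rebuilt from the inverted sigmas
```
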
@@ -442,14 +502,21 @@ def set_timesteps(
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
             self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
+        elif self.config.use_flow_match:
+            self.timesteps = sigmas * self.config.num_train_timesteps
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
 
         self._step_index = None
         self._begin_index = None
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _sigma_to_t(self, sigma, log_sigmas=None):
+        if self.config.use_flow_match:
+            return sigma * self.config.num_train_timesteps
         # get log sigma
         log_sigma = np.log(np.maximum(sigma, 1e-10))
 
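
`time_shift` with its default exponent `sigma=1.0` belongs to the same family of maps as the static shift used in `__init__`: substituting `shift = exp(mu)` makes the two formulas algebraically identical. A short check (values assumed):

```python
import math

mu, t = 1.0986, 0.25                 # exp(mu) ~ 3.0; t plays the role of a sigma in (0, 1)
shift = math.exp(mu)

dynamic = math.exp(mu) / (math.exp(mu) + (1 / t - 1))   # time_shift(mu, 1.0, t)
static = shift * t / (1 + (shift - 1) * t)              # static shift formula
assert abs(dynamic - static) < 1e-12
```
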
@@ -622,7 +689,7 @@ def step(
             ),
         )
 
-        if not self.is_scale_input_called:
+        if not self.is_scale_input_called and not self.config.use_flow_match:
             logger.warning(
                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                 "See `StableDiffusionPipeline` for a usage example."
@@ -663,7 +730,10 @@ def step(
             )
 
         # 2. Convert to an ODE derivative
-        derivative = (sample - pred_original_sample) / sigma_hat
+        if self.config.use_flow_match:
+            derivative = model_output
+        else:
+            derivative = (sample - pred_original_sample) / sigma_hat
 
         dt = self.sigmas[self.step_index + 1] - sigma_hat
 
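
In flow-match mode the model output is treated directly as the velocity field, so the Euler update this hunk feeds into reduces to moving the sample along `model_output` by the sigma decrement. A standalone sketch of that update (names are hypothetical):

```python
import torch

def euler_flow_match_step(sample: torch.Tensor, velocity: torch.Tensor,
                          sigma: float, sigma_next: float) -> torch.Tensor:
    # dx/dsigma = v  =>  x_next = x + (sigma_next - sigma) * v
    return sample + (sigma_next - sigma) * velocity

x = torch.randn(1, 4, 64, 64)      # pure noise at sigma = 1.0
v = torch.zeros_like(x)            # stand-in for model_output
x = euler_flow_match_step(x, v, sigma=1.0, sigma_next=0.95)
```
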
@@ -713,7 +783,10 @@ def add_noise(
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
-        noisy_samples = original_samples + noise * sigma
+        if self.config.use_flow_match:
+            noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
+        else:
+            noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
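
The two noising conventions in this final hunk differ in kind: flow matching interpolates linearly between the clean sample and noise, while the existing path adds scaled noise on top. A side-by-side sketch:

```python
import torch

x0 = torch.randn(2, 3)
noise = torch.randn_like(x0)
sigma = 0.7

# flow match: sigma = 0 gives the clean sample, sigma = 1 gives pure noise
flow_match = (1.0 - sigma) * x0 + sigma * noise

# original (variance-exploding) path: noise is added on top of the sample
variance_exploding = x0 + sigma * noise
```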