@@ -403,10 +403,35 @@ def set_timesteps(
403403 num_inference_steps = len (timesteps ) if timesteps is not None else len (sigmas ) - 1
404404 self .num_inference_steps = num_inference_steps
405405
406+ if self .config .final_sigmas_type == "sigma_min" and not self .config .use_flow_match :
407+ sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
408+ elif self .config .final_sigmas_type == "zero" :
409+ sigma_last = 0
410+ elif self .config .invert_sigmas :
411+ sigma_last = 1
412+ else :
413+ raise ValueError (
414+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
415+ )
416+
406417 if sigmas is not None and not self .config .use_flow_match :
407418 log_sigmas = np .log (np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 ))
408419 sigmas = np .array (sigmas ).astype (np .float32 )
409420 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas [:- 1 ]])
421+ elif sigmas is not None and self .config .use_flow_match :
422+ sigmas = np .array (sigmas ).astype (np .float32 )
423+ timesteps = sigmas * self .config .num_train_timesteps
424+
425+ if self .config .use_flow_match :
426+ if self .config .use_dynamic_shifting :
427+ sigmas = self .time_shift (mu , 1.0 , sigmas )
428+ else :
429+ sigmas = self .config .shift * sigmas / (1 + (self .config .shift - 1 ) * sigmas )
430+
431+ if self .config .invert_sigmas :
432+ sigmas = 1.0 - sigmas
433+ timesteps = sigmas * self .config .num_train_timesteps
434+ sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
410435 elif sigmas is None :
411436 if timesteps is not None :
412437 timesteps = np .array (timesteps ).astype (np .float32 )
@@ -482,18 +507,7 @@ def set_timesteps(
482507
483508 if self .config .invert_sigmas :
484509 sigmas = 1.0 - sigmas
485- timesteps = sigmas * self .config .num_train_timesteps
486-
487- if self .config .final_sigmas_type == "sigma_min" :
488- sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
489- elif self .config .final_sigmas_type == "zero" :
490- sigma_last = 0
491- elif self .config .invert_sigmas :
492- sigma_last = 1
493- else :
494- raise ValueError (
495- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
496- )
510+ timesteps = sigmas * self .config .num_train_timesteps
497511
498512 sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
499513
@@ -502,8 +516,6 @@ def set_timesteps(
502516 # TODO: Support the full EDM scalings for all prediction types and timestep types
503517 if self .config .timestep_type == "continuous" and self .config .prediction_type == "v_prediction" :
504518 self .timesteps = torch .Tensor ([0.25 * sigma .log () for sigma in sigmas [:- 1 ]]).to (device = device )
505- elif self .config .use_flow_match :
506- self .timesteps = sigmas [:- 1 ] * self .config .num_train_timesteps
507519 else :
508520 self .timesteps = torch .from_numpy (timesteps .astype (np .float32 )).to (device = device )
509521
0 commit comments