@@ -251,12 +251,14 @@ def __init__(
251251 0 + timestep_offset , num_train_timesteps - 1 + timestep_offset , num_train_timesteps , dtype = float
252252 )[::- 1 ].copy ()
253253 timesteps = torch .from_numpy (timesteps ).to (dtype = torch .float32 )
254+ print (timesteps )
254255
255256 if use_flow_match :
256257 sigmas = timesteps / num_train_timesteps
257258 if not use_dynamic_shifting :
258259 # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
259260 sigmas = shift * sigmas / (1 + (shift - 1 ) * sigmas )
261+ print (sigmas )
260262 else :
261263 sigmas = (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 ).flip (0 )
262264
@@ -268,6 +270,7 @@ def __init__(
268270 self .timesteps = torch .Tensor ([0.25 * sigma .log () for sigma in sigmas ])
269271 elif use_flow_match :
270272 self .timesteps = sigmas * num_train_timesteps
273+ print (self .timesteps )
271274 else :
272275 self .timesteps = timesteps
273276
@@ -407,6 +410,19 @@ def set_timesteps(
407410 log_sigmas = np .log (np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 ))
408411 sigmas = np .array (sigmas ).astype (np .float32 )
409412 timesteps = np .array ([self ._sigma_to_t (sigma , log_sigmas ) for sigma in sigmas [:- 1 ]])
413+ elif sigmas is not None and self .config .use_flow_match :
414+ sigmas = np .array (sigmas ).astype (np .float32 )
415+ timesteps = sigmas * self .config .num_train_timesteps
416+
417+ if self .config .use_flow_match :
418+ if self .config .use_dynamic_shifting :
419+ sigmas = self .time_shift (mu , 1.0 , sigmas )
420+ else :
421+ sigmas = self .config .shift * sigmas / (1 + (self .config .shift - 1 ) * sigmas )
422+
423+ if self .config .invert_sigmas :
424+ sigmas = 1.0 - sigmas
425+ timesteps = sigmas * self .config .num_train_timesteps
410426 elif sigmas is None :
411427 if timesteps is not None :
412428 timesteps = np .array (timesteps ).astype (np .float32 )
@@ -482,7 +498,7 @@ def set_timesteps(
482498
483499 if self .config .invert_sigmas :
484500 sigmas = 1.0 - sigmas
485- timesteps = sigmas * self .config .num_train_timesteps
501+ timesteps = sigmas * self .config .num_train_timesteps
486502
487503 if self .config .final_sigmas_type == "sigma_min" :
488504 sigma_last = ((1 - self .alphas_cumprod [0 ]) / self .alphas_cumprod [0 ]) ** 0.5
@@ -502,8 +518,6 @@ def set_timesteps(
502518 # TODO: Support the full EDM scalings for all prediction types and timestep types
503519 if self .config .timestep_type == "continuous" and self .config .prediction_type == "v_prediction" :
504520 self .timesteps = torch .Tensor ([0.25 * sigma .log () for sigma in sigmas [:- 1 ]]).to (device = device )
505- elif self .config .use_flow_match :
506- self .timesteps = sigmas [:- 1 ] * self .config .num_train_timesteps
507521 else :
508522 self .timesteps = torch .from_numpy (timesteps .astype (np .float32 )).to (device = device )
509523
0 commit comments