Skip to content

Commit 4e9119f

Browse files
committed
timesteps, set_timesteps(sigmas=..)
1 parent 98a52db commit 4e9119f

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,14 @@ def __init__(
251251
0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
252252
)[::-1].copy()
253253
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
254+
print(timesteps)
254255

255256
if use_flow_match:
256257
sigmas = timesteps / num_train_timesteps
257258
if not use_dynamic_shifting:
258259
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
259260
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
261+
print(sigmas)
260262
else:
261263
sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
262264

@@ -268,6 +270,7 @@ def __init__(
268270
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
269271
elif use_flow_match:
270272
self.timesteps = sigmas * num_train_timesteps
273+
print(self.timesteps)
271274
else:
272275
self.timesteps = timesteps
273276

@@ -407,6 +410,19 @@ def set_timesteps(
407410
log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
408411
sigmas = np.array(sigmas).astype(np.float32)
409412
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
413+
elif sigmas is not None and self.config.use_flow_match:
414+
sigmas = np.array(sigmas).astype(np.float32)
415+
timesteps = sigmas * self.config.num_train_timesteps
416+
417+
if self.config.use_flow_match:
418+
if self.config.use_dynamic_shifting:
419+
sigmas = self.time_shift(mu, 1.0, sigmas)
420+
else:
421+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
422+
423+
if self.config.invert_sigmas:
424+
sigmas = 1.0 - sigmas
425+
timesteps = sigmas * self.config.num_train_timesteps
410426
elif sigmas is None:
411427
if timesteps is not None:
412428
timesteps = np.array(timesteps).astype(np.float32)
@@ -482,7 +498,7 @@ def set_timesteps(
482498

483499
if self.config.invert_sigmas:
484500
sigmas = 1.0 - sigmas
485-
timesteps = sigmas * self.config.num_train_timesteps
501+
timesteps = sigmas * self.config.num_train_timesteps
486502

487503
if self.config.final_sigmas_type == "sigma_min":
488504
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -502,8 +518,6 @@ def set_timesteps(
502518
# TODO: Support the full EDM scalings for all prediction types and timestep types
503519
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
504520
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
505-
elif self.config.use_flow_match:
506-
self.timesteps = sigmas[:-1] * self.config.num_train_timesteps
507521
else:
508522
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
509523

0 commit comments

Comments
 (0)