Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
stochastic_sampling (`bool`, defaults to False):
Whether to use stochastic sampling.
"""

_compatibles = []
Expand All @@ -101,6 +103,7 @@ def __init__(
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
Expand Down Expand Up @@ -437,13 +440,25 @@ def step(
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
dt = (per_token_sigmas - lower_sigmas)[..., None]

current_sigma = per_token_sigmas[..., None]
next_sigma = lower_sigmas[..., None]
dt = current_sigma - next_sigma
else:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
sigma_idx = self.step_index
sigma = self.sigmas[sigma_idx]
sigma_next = self.sigmas[sigma_idx + 1]

current_sigma = sigma
next_sigma = sigma_next
dt = sigma_next - sigma

prev_sample = sample + dt * model_output
if self.config.stochastic_sampling:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
else:
prev_sample = sample + dt * model_output

# upon completion increase step index by one
self._step_index += 1
Expand Down