diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index cbb27e5fad63..575423ee80e7 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Whether to use beta sigmas for step sizes in the noise schedule during sampling. time_shift_type (`str`, defaults to "exponential"): The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. """ _compatibles = [] @@ -101,6 +103,7 @@ def __init__( use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, time_shift_type: str = "exponential", + stochastic_sampling: bool = False, ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -437,13 +440,25 @@ def step( lower_mask = sigmas < per_token_sigmas[None] - 1e-6 lower_sigmas = lower_mask * sigmas lower_sigmas, _ = lower_sigmas.max(dim=0) - dt = (per_token_sigmas - lower_sigmas)[..., None] + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma else: - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] + sigma_idx = self.step_index + sigma = self.sigmas[sigma_idx] + sigma_next = self.sigmas[sigma_idx + 1] + + current_sigma = sigma + next_sigma = sigma_next dt = sigma_next - sigma - prev_sample = sample + dt * model_output + if self.config.stochastic_sampling: + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + prev_sample = sample + dt * model_output # upon completion increase step index by one self._step_index += 1