Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
stochastic_sampling (`bool`, defaults to False):
Whether to use stochastic sampling.
"""

_compatibles = []
Expand All @@ -101,6 +103,7 @@ def __init__(
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
Expand Down Expand Up @@ -378,6 +381,7 @@ def step(
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
per_token_timesteps: Optional[torch.Tensor] = None,
stochastic_sampling: Optional[bool] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Expand All @@ -400,6 +404,8 @@ def step(
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
stochastic_sampling (`bool`, *optional*):
Whether to use stochastic sampling. If None, defaults to the value set in the scheduler's config.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
Expand Down Expand Up @@ -437,13 +443,28 @@ def step(
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
dt = (per_token_sigmas - lower_sigmas)[..., None]

current_sigma = per_token_sigmas[..., None]
next_sigma = lower_sigmas[..., None]
dt = next_sigma - current_sigma # Equivalent to sigma_next - sigma
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apolinario
here it seems to reversed, no?
before:
dt = (per_token_sigmas - lower_sigmas)[..., None]

now:
dt = ower_sigmas - per_token_sigmas

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

else:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
sigma_idx = self.step_index
sigma = self.sigmas[sigma_idx]
sigma_next = self.sigmas[sigma_idx + 1]

current_sigma = sigma
next_sigma = sigma_next
dt = sigma_next - sigma

prev_sample = sample + dt * model_output
# Determine whether to use stochastic sampling for this step
use_stochastic = stochastic_sampling if stochastic_sampling is not None else self.config.stochastic_sampling
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just have this in config is enough no?


if use_stochastic:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
else:
prev_sample = sample + dt * model_output

# upon completion increase step index by one
self._step_index += 1
Expand Down