-
Couldn't load subscription status.
- Fork 6.4k
Add stochastic sampling to FlowMatchEulerDiscreteScheduler #11369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
690adb5
f87956e
9edc5be
32d9aef
9c35a89
25bc77d
ff1012f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): | |
| Whether to use beta sigmas for step sizes in the noise schedule during sampling. | ||
| time_shift_type (`str`, defaults to "exponential"): | ||
| The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". | ||
| stochastic_sampling (`bool`, defaults to False): | ||
| Whether to use stochastic sampling. | ||
| """ | ||
|
|
||
| _compatibles = [] | ||
|
|
@@ -101,6 +103,7 @@ def __init__( | |
| use_exponential_sigmas: Optional[bool] = False, | ||
| use_beta_sigmas: Optional[bool] = False, | ||
| time_shift_type: str = "exponential", | ||
| stochastic_sampling: bool = False, | ||
| ): | ||
| if self.config.use_beta_sigmas and not is_scipy_available(): | ||
| raise ImportError("Make sure to install scipy if you want to use beta sigmas.") | ||
|
|
@@ -378,6 +381,7 @@ def step( | |
| s_noise: float = 1.0, | ||
| generator: Optional[torch.Generator] = None, | ||
| per_token_timesteps: Optional[torch.Tensor] = None, | ||
| stochastic_sampling: Optional[bool] = None, | ||
| return_dict: bool = True, | ||
| ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: | ||
| """ | ||
|
|
@@ -400,6 +404,8 @@ def step( | |
| A random number generator. | ||
| per_token_timesteps (`torch.Tensor`, *optional*): | ||
| The timesteps for each token in the sample. | ||
| stochastic_sampling (`bool`, *optional*): | ||
| Whether to use stochastic sampling. If None, defaults to the value set in the scheduler's config. | ||
apolinario marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return_dict (`bool`): | ||
| Whether or not to return a | ||
| [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. | ||
|
|
@@ -437,13 +443,28 @@ def step( | |
| lower_mask = sigmas < per_token_sigmas[None] - 1e-6 | ||
| lower_sigmas = lower_mask * sigmas | ||
| lower_sigmas, _ = lower_sigmas.max(dim=0) | ||
| dt = (per_token_sigmas - lower_sigmas)[..., None] | ||
|
|
||
| current_sigma = per_token_sigmas[..., None] | ||
| next_sigma = lower_sigmas[..., None] | ||
| dt = next_sigma - current_sigma # Equivalent to sigma_next - sigma | ||
|
||
| else: | ||
| sigma = self.sigmas[self.step_index] | ||
| sigma_next = self.sigmas[self.step_index + 1] | ||
| sigma_idx = self.step_index | ||
| sigma = self.sigmas[sigma_idx] | ||
| sigma_next = self.sigmas[sigma_idx + 1] | ||
|
|
||
| current_sigma = sigma | ||
| next_sigma = sigma_next | ||
| dt = sigma_next - sigma | ||
|
|
||
| prev_sample = sample + dt * model_output | ||
| # Determine whether to use stochastic sampling for this step | ||
| use_stochastic = stochastic_sampling if stochastic_sampling is not None else self.config.stochastic_sampling | ||
|
||
|
|
||
| if use_stochastic: | ||
| x0 = sample - current_sigma * model_output | ||
| noise = torch.randn_like(sample) | ||
| prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise | ||
| else: | ||
| prev_sample = sample + dt * model_output | ||
|
|
||
| # upon completion increase step index by one | ||
| self._step_index += 1 | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.