
Commit f12841c
1 parent ada44e7

14 files changed: +680 −732 lines

src/diffusers/configuration_utils.py
Lines changed: 28 additions & 0 deletions

@@ -245,6 +245,34 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
             deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
             config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

+        # Handle old scheduler configs
+        if "Scheduler" in cls.__name__ and "schedule_config" not in config:
+            prediction_type = config.pop("prediction_type", None)
+            _class_name = config.pop("_class_name", None)
+            _diffusers_version = config.pop("_diffusers_version", None)
+            use_karras_sigmas = config.pop("use_karras_sigmas", None)
+            use_exponential_sigmas = config.pop("use_exponential_sigmas", None)
+            use_beta_sigmas = config.pop("use_beta_sigmas", None)
+            if use_karras_sigmas:
+                sigma_schedule_config = {"class_name": "KarrasSigmas"}
+            elif use_exponential_sigmas:
+                sigma_schedule_config = {"class_name": "ExponentialSigmas"}
+            elif use_beta_sigmas:
+                sigma_schedule_config = {"class_name": "BetaSigmas"}
+            else:
+                sigma_schedule_config = {}
+            if "beta_schedule" in config:
+                config.update({"class_name": "BetaSchedule"})
+            elif "shift" in config:
+                config.update({"class_name": "FlowMatchSchedule"})
+            config = {
+                "_class_name": _class_name,
+                "_diffusers_version": _diffusers_version,
+                "prediction_type": prediction_type,
+                "schedule_config": config,
+                "sigma_schedule_config": sigma_schedule_config,
+            }
+
         init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

         # Allow dtype to be specified on initialization
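
To make the remapping concrete, here is a sketch of what the migration branch above does to an old flat scheduler config. The legacy field values are illustrative (a hypothetical Euler-style config); the exact field set varies per scheduler.

# Old-style flat scheduler config (no "schedule_config"), as it might be loaded from disk:
old_config = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.30.0",
    "prediction_type": "epsilon",
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "use_karras_sigmas": True,
}

# After the branch above, the same information is nested under the new keys:
new_config = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.30.0",
    "prediction_type": "epsilon",
    "schedule_config": {
        "class_name": "BetaSchedule",  # chosen because "beta_schedule" was present
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.012,
    },
    "sigma_schedule_config": {"class_name": "KarrasSigmas"},  # from use_karras_sigmas=True
}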

src/diffusers/pipelines/flux/pipeline_flux.py
Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,6 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union

-import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

@@ -699,7 +698,8 @@ def __call__(
         )

         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        if self.scheduler.schedule.__class__.__name__ != "FlowMatchFlux":
+            self.scheduler._schedule.set_base_schedule("FlowMatchFlux")
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
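
The replacement relies on the scheduler exposing its active base schedule as `scheduler.schedule` and a `set_base_schedule` method on the underlying `_schedule` object; neither class appears in this commit, so the following is only a hypothetical caller-side usage sketch restating the same assumption.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# The pipeline now performs this switch itself before preparing timesteps, but a
# caller could pin the base schedule explicitly the same way (names taken from the diff above):
if pipe.scheduler.schedule.__class__.__name__ != "FlowMatchFlux":
    pipe.scheduler._schedule.set_base_schedule("FlowMatchFlux")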

src/diffusers/pipelines/mochi/pipeline_mochi.py
Lines changed: 3 additions & 5 deletions

@@ -15,7 +15,6 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union

-import numpy as np
 import torch
 from transformers import T5EncoderModel, T5TokenizerFast

@@ -495,6 +494,7 @@ def __call__(
         num_frames: int = 19,
         num_inference_steps: int = 64,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         guidance_scale: float = 4.5,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,

@@ -652,10 +652,8 @@ def __call__(
         prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

         # 5. Prepare timestep
-        # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
-        threshold_noise = 0.025
-        sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
-        sigmas = np.array(sigmas)
+        if self.scheduler.schedule.__class__.__name__ != "FlowMatchLinearQuadratic":
+            self.scheduler._schedule.set_base_schedule("FlowMatchLinearQuadratic")

         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
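
Alongside the schedule switch, `__call__` gains a `sigmas` argument. Assuming `retrieve_timesteps` forwards it to `scheduler.set_timesteps` (that part of the call is not shown in this hunk), a caller could still supply an explicit sigma list instead of relying on the FlowMatchLinearQuadratic base schedule. A sketch, with the model id and prompt given purely for illustration:

import numpy as np
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)

num_inference_steps = 64
# Hypothetical custom schedule: linearly decaying sigmas in (0, 1].
custom_sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps).tolist()

frames = pipe(
    prompt="a close-up of a chameleon changing color",
    num_inference_steps=num_inference_steps,
    sigmas=custom_sigmas,
).frames[0]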

src/diffusers/schedulers/schedules/__init__.py
Whitespace-only changes.

Lines changed: 252 additions & 0 deletions

@@ -0,0 +1,252 @@
from typing import List, Optional, Union

import math
import numpy as np
import torch

from ...configuration_utils import ConfigMixin, register_to_config
from ..sigmas.beta_sigmas import BetaSigmas
from ..sigmas.exponential_sigmas import ExponentialSigmas
from ..sigmas.karras_sigmas import KarrasSigmas

def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class BetaSchedule:

    scale_model_input = True

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        rescale_betas_zero_snr: bool = False,
        interpolation_type: str = "linear",
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
        steps_offset: int = 0,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
        **kwargs,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.trained_betas = trained_betas
        self.rescale_betas_zero_snr = rescale_betas_zero_snr
        self.interpolation_type = interpolation_type
        self.timestep_spacing = timestep_spacing
        self.timestep_type = timestep_type
        self.steps_offset = steps_offset
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.final_sigmas_type = final_sigmas_type

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def __call__(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        sigma_schedule: Optional[Union[KarrasSigmas, ExponentialSigmas, BetaSigmas]] = None,
        **kwargs,
    ):
        if sigmas is not None:
            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
            sigmas = np.array(sigmas).astype(np.float32)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

        else:
            if timesteps is not None:
                timesteps = np.array(timesteps).astype(np.float32)
            else:
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
                if self.timestep_spacing == "linspace":
                    timesteps = np.linspace(
                        0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
                    )[::-1].copy()
                elif self.timestep_spacing == "leading":
                    step_ratio = self.num_train_timesteps // num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                    )
                    timesteps += self.steps_offset
                elif self.timestep_spacing == "trailing":
                    step_ratio = self.num_train_timesteps / num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                    )
                    timesteps -= 1
                else:
                    raise ValueError(
                        f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                    )

            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            if self.interpolation_type == "linear":
                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            elif self.interpolation_type == "log_linear":
                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
            else:
                raise ValueError(
                    f"{self.interpolation_type} is not implemented. Please specify interpolation_type to either"
                    " 'linear' or 'log_linear'"
                )

            if sigma_schedule is not None:
                sigmas = sigma_schedule(sigmas)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            if self.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            elif self.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.final_sigmas_type}"
                )

            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.timestep_type == "continuous" and self.prediction_type == "v_prediction":
            timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
        else:
            timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        return sigmas, timesteps
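
BetaSchedule is a plain callable rather than a full scheduler: it owns the beta/alpha tables and turns a step count (or explicit timesteps/sigmas) into the (sigmas, timesteps) pair a scheduler consumes. A minimal usage sketch, assuming the class above is importable in your environment; the optional sigma_schedule argument is omitted because the KarrasSigmas/ExponentialSigmas/BetaSigmas call signatures are not part of this file:

# Build a schedule equivalent to the classic Stable Diffusion beta setup (values illustrative).
schedule = BetaSchedule(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="leading",
)

# Returns torch tensors on the requested device.
sigmas, timesteps = schedule(num_inference_steps=30, device="cpu")
print(sigmas.shape)     # torch.Size([31]) -- 30 sigmas plus the trailing sigma_last (0 by default)
print(timesteps.shape)  # torch.Size([30])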
