
Commit cc849e2

Notes, sana schedule, scale_noise->add_noise

1 parent f12841c commit cc849e2

6 files changed: +68 −35 lines changed

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 3 additions & 1 deletion
@@ -562,7 +562,9 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        # NOTE: `scale_noise` changed to `add_noise`
+        # the signature is `noise`, `timestep` instead of `timestep`, `noise`
+        latents = self.scheduler.add_noise(image_latents, noise, timestep)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
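As the NOTE in the hunk says, the rename also swaps the trailing argument order. A minimal sketch of the change at a call site; `_DummyScheduler` is a stand-in for illustration only, not the diffusers API:

import torch

class _DummyScheduler:
    # Stand-in for illustration; a real scheduler comes from diffusers.
    def add_noise(self, sample, noise, timestep):
        sigma = timestep / 1000.0  # toy sigma lookup, not a real schedule
        return (1.0 - sigma) * sample + sigma * noise

scheduler = _DummyScheduler()
sample = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(sample)
timestep = torch.tensor(500.0)

# Old call: scheduler.scale_noise(sample, timestep, noise)
# New call: `noise` now comes before `timestep`.
latents = scheduler.add_noise(sample, noise, timestep)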

src/diffusers/schedulers/schedules/beta_schedule.py

Lines changed: 11 additions & 12 deletions
@@ -1,14 +1,14 @@
+import math
 from typing import List, Optional, Union
 
-import math
 import numpy as np
 import torch
 
-from ...configuration_utils import ConfigMixin, register_to_config
 from ..sigmas.beta_sigmas import BetaSigmas
 from ..sigmas.exponential_sigmas import ExponentialSigmas
 from ..sigmas.karras_sigmas import KarrasSigmas
 
+
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
     max_beta=0.999,
@@ -52,6 +52,7 @@ def alpha_bar_fn(t):
         betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
     return torch.tensor(betas, dtype=torch.float32)
 
+
 def rescale_zero_terminal_snr(betas):
     """
     Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
@@ -89,7 +90,6 @@ def rescale_zero_terminal_snr(betas):
 
 
 class BetaSchedule:
-
     scale_model_input = True
 
     def __init__(
@@ -132,7 +132,7 @@ def __init__(
         # Close to 0 without being 0 so first sigma is not inf
         # FP16 smallest positive subnormal works well here
         self.alphas_cumprod[-1] = 2**-24
-
+
         self.num_train_timesteps = num_train_timesteps
         self.beta_start = beta_start
         self.beta_end = beta_end
@@ -181,6 +181,7 @@ def __call__(
     ):
         if sigmas is not None:
             log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
+            # NOTE: current usage is **with** `sigma_last` - different than FlowMatch.
             sigmas = np.array(sigmas).astype(np.float32)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
 
@@ -190,9 +191,9 @@ def __call__(
         else:
             # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
             if self.timestep_spacing == "linspace":
-                timesteps = np.linspace(
-                    0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
-                )[::-1].copy()
+                timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
+                    ::-1
+                ].copy()
             elif self.timestep_spacing == "leading":
                 step_ratio = self.num_train_timesteps // num_inference_steps
                 # creates integer timesteps by multiplying by ratio
@@ -205,9 +206,7 @@ def __call__(
                 step_ratio = self.num_train_timesteps / num_inference_steps
                 # creates integer timesteps by multiplying by ratio
                 # casting to int to avoid issues when num_inference_step is power of 3
-                timesteps = (
-                    (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
-                )
+                timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                 timesteps -= 1
             else:
                 raise ValueError(
@@ -240,13 +239,13 @@ def __call__(
                )
 
         sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
-
+
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.timestep_type == "continuous" and self.prediction_type == "v_prediction":
            timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
         else:
            timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
-
+
         return sigmas, timesteps
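For reference, the three `timestep_spacing` branches reformatted above correspond to Table 2 of https://arxiv.org/abs/2305.08891 and produce different grids. A standalone numpy sketch with `num_train_timesteps=1000` and `num_inference_steps=10`; the `leading` computation extends past the hunk shown, so that line follows the long-standing diffusers formula with `steps_offset` omitted:

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

# "linspace": evenly spaced over [0, T-1], then reversed.
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()

# "leading": integer multiples of T // N, reversed (steps_offset omitted here).
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)

# "trailing": counts down from T in steps of T / N, then shifts by -1.
step_ratio = num_train_timesteps / num_inference_steps
trailing = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
trailing -= 1

print(linspace)  # [999. 888. ... 0.]
print(leading)   # [900. 800. ... 0.]
print(trailing)  # [999. 899. ... 99.]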

src/diffusers/schedulers/schedules/flow_schedule.py

Lines changed: 42 additions & 19 deletions
@@ -1,21 +1,32 @@
+import math
 from typing import List, Optional, Union
 
-import math
 import numpy as np
 import torch
 
-from ...configuration_utils import ConfigMixin, register_to_config
 from ..sigmas.beta_sigmas import BetaSigmas
 from ..sigmas.exponential_sigmas import ExponentialSigmas
 from ..sigmas.karras_sigmas import KarrasSigmas
 
-class FlowMatchSD3:
-
-    def _sigma_to_t(self, sigma):
-        return sigma * self.num_train_timesteps
 
-    def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: float, use_dynamic_shifting: bool = False, **kwargs):
-        self.num_train_timesteps = num_train_timesteps
+class FlowMatchSD3:
+    def __call__(
+        self,
+        num_inference_steps: int,
+        num_train_timesteps: int,
+        shift: float,
+        use_dynamic_shifting: bool = False,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        This is different to others that directly calculate `sigmas`.
+        It needs `sigma_min` and `sigma_max` after shift
+        https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L89-L95
+        Then we calculate `sigmas` from that `sigma_min` and `sigma_max`.
+        https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L238-L240
+        Shifting happens again after (outside of this).
+        https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L248-L251
+        """
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
@@ -25,18 +36,20 @@ def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: fl
         sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
         sigma_min = sigmas[-1].item()
         sigma_max = sigmas[0].item()
-        timesteps = np.linspace(
-            self._sigma_to_t(sigma_max), self._sigma_to_t(sigma_min), num_inference_steps
-        )
+        timesteps = np.linspace(sigma_max * num_train_timesteps, sigma_min * num_train_timesteps, num_inference_steps)
         sigmas = timesteps / num_train_timesteps
         return sigmas
 
+
 class FlowMatchFlux:
-    def __call__(self, num_inference_steps: int, **kwargs):
+    def __call__(self, num_inference_steps: int, **kwargs) -> np.ndarray:
         return np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
 
+
 class FlowMatchLinearQuadratic:
-    def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, linear_steps: Optional[int] = None, **kwargs):
+    def __call__(
+        self, num_inference_steps: int, threshold_noise: float = 0.25, linear_steps: Optional[int] = None, **kwargs
+    ) -> np.ndarray:
         if linear_steps is None:
             linear_steps = num_inference_steps // 2
         linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
@@ -49,22 +62,33 @@ def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, line
             quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_inference_steps)
         ]
         sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
-        sigma_schedule = [1.0 - x for x in sigma_schedule]
+        sigma_schedule = np.array([1.0 - x for x in sigma_schedule]).astype(np.float32)
         return sigma_schedule
 
+
 class FlowMatchHunyuanVideo:
-    def __call__(self, num_inference_steps: int, **kwargs):
+    def __call__(self, num_inference_steps: int, **kwargs) -> np.ndarray:
         return np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1].copy()
 
+
+class FlowMatchSANA:
+    def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: float, **kwargs) -> np.ndarray:
+        alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
+        sigmas = 1.0 - alphas
+        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
+        return sigmas
+
+
 BASE_SCHEDULE_MAP = {
     "FlowMatchHunyuanVideo": FlowMatchHunyuanVideo,
     "FlowMatchLinearQuadratic": FlowMatchLinearQuadratic,
     "FlowMatchFlux": FlowMatchFlux,
     "FlowMatchSD3": FlowMatchSD3,
+    "FlowMatchSANA": FlowMatchSANA,
 }
 
-class FlowMatchSchedule:
 
+class FlowMatchSchedule:
     scale_model_input = False
 
     base_schedules = BASE_SCHEDULE_MAP
@@ -145,7 +169,7 @@ def __call__(
     ):
         shift = shift or self.shift
         if self.use_dynamic_shifting and mu is None:
-            raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
+            raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if sigmas is None:
             sigmas = self.base_schedule(
@@ -155,9 +179,8 @@ def __call__(
                 use_dynamic_shifting=self.use_dynamic_shifting,
             )
         else:
+            # NOTE: current usage is **without** `sigma_last` - different than BetaSchedule
             sigmas = np.array(sigmas).astype(np.float32)
-            num_inference_steps = len(sigmas)
-            self.num_inference_steps = num_inference_steps
 
         if self.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
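The new `FlowMatchSANA` schedule is small enough to sanity-check in isolation. A minimal sketch mirroring the class body above, with assumed defaults `num_train_timesteps=1000` and `shift=3.0`:

import numpy as np

def sana_sigmas(num_inference_steps, num_train_timesteps=1000, shift=3.0):
    # Mirrors FlowMatchSANA.__call__ from the diff above.
    alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)
    sigmas = 1.0 - alphas
    # Shift, reverse to descending order, and drop the terminal 0.0 entry.
    return np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()

print(sana_sigmas(4))  # approx. [0.9997, 0.8996, 0.7496, 0.4997]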

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 4 additions & 1 deletion
@@ -384,7 +384,10 @@ def add_noise(
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
-        noisy_samples = original_samples + noise * sigma
+        if self._schedule.__class__.__name__ == "FlowMatchSchedule":
+            noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
+        else:
+            noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
     def __len__(self):

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 4 additions & 1 deletion
@@ -413,7 +413,10 @@ def add_noise(
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
-        noisy_samples = original_samples + noise * sigma
+        if self._schedule.__class__.__name__ == "FlowMatchSchedule":
+            noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
+        else:
+            noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,10 @@ def add_noise(
401401
while len(sigma.shape) < len(original_samples.shape):
402402
sigma = sigma.unsqueeze(-1)
403403

404-
noisy_samples = original_samples + noise * sigma
404+
if self._schedule.__class__.__name__ == "FlowMatchSchedule":
405+
noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
406+
else:
407+
noisy_samples = original_samples + noise * sigma
405408
return noisy_samples
406409

407410
def __len__(self):
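The same three-line branch lands in the Euler ancestral, Euler, and Heun schedulers: with a `FlowMatchSchedule` attached, `add_noise` interpolates between sample and noise instead of adding scaled noise on top. A standalone sketch of the two conventions, where `flow_match` stands in for the `self._schedule.__class__.__name__` check:

import torch

def add_noise(original_samples, noise, sigma, flow_match=False):
    if flow_match:
        # Flow-match convention: linear interpolation between data and noise.
        return (1.0 - sigma) * original_samples + noise * sigma
    # Pre-existing convention: scaled noise added on top of the sample.
    return original_samples + noise * sigma

x = torch.full((2, 3), 2.0)
eps = torch.ones(2, 3)
sigma = torch.tensor(1.0)
print(add_noise(x, eps, sigma, flow_match=True))   # all 1.0: pure noise at sigma = 1
print(add_noise(x, eps, sigma, flow_match=False))  # all 3.0: x + eps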
