Commit 230a93c: Combine Flow Match Euler into Euler
1 parent 7ac6e28

10 files changed: +139 additions, -326 deletions
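
In short, this commit folds FlowMatchEulerDiscreteScheduler's behavior into EulerDiscreteScheduler behind a new `use_flow_match` config flag and switches the flow-based pipelines from `scale_noise(...)` to the Euler scheduler's `add_noise(...)`. Below is a minimal sketch of how the merged scheduler might be configured once a diffusers build includes this commit; the parameter values are illustrative, not defaults of any released checkpoint.

    # Illustrative sketch only: exercises the `use_flow_match` flag and the FlowMatch-style
    # config fields that this commit adds to EulerDiscreteScheduler. Values are examples.
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler(
        num_train_timesteps=1000,
        use_flow_match=True,         # new flag: flow-matching sigmas and stepping
        shift=3.0,                   # static timestep shift, ignored when use_dynamic_shifting=True
        use_dynamic_shifting=False,  # when True, set_timesteps() expects a `mu` argument
    )
    scheduler.set_timesteps(num_inference_steps=28)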

examples/community/pipeline_flux_differential_img2img.py

Lines changed: 3 additions & 3 deletions
@@ -582,7 +582,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             noise = latents.to(device)
             latents = noise
@@ -976,8 +976,8 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    image_latent = self.scheduler.scale_noise(
-                        original_image_latents, torch.tensor([noise_timestep]), noise
+                    image_latent = self.scheduler.add_noise(
+                        original_image_latents, noise, torch.tensor([noise_timestep])
                     )
 
                 # start diff diff
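
The same one-line substitution repeats in every pipeline touched by this commit: FlowMatchEulerDiscreteScheduler exposed `scale_noise(sample, timestep, noise)`, while EulerDiscreteScheduler exposes `add_noise(original_samples, noise, timesteps)`, so besides the method rename the noise and timestep arguments swap places at each call site. A small sketch of the change (the helper below is illustrative, not pipeline code):

    # `scheduler` is assumed to be the merged EulerDiscreteScheduler with use_flow_match=True.
    def noise_image_latents(scheduler, image_latents, noise, timestep):
        # old API: scheduler.scale_noise(image_latents, timestep, noise)
        # new API: noise and timestep trade places
        return scheduler.add_noise(image_latents, noise, timestep)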

examples/community/pipeline_stable_diffusion_3_differential_img2img.py

Lines changed: 1 addition & 1 deletion
@@ -640,7 +640,7 @@ def prepare_latents(
         shape = init_latents.shape
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents.to(device=device, dtype=dtype)
 
         return latents

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 1 addition & 1 deletion
@@ -579,7 +579,7 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        latents = self.scheduler.add_noise(image_latents, noise, timestep)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 3 additions & 3 deletions
@@ -605,7 +605,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             noise = latents.to(device)
             latents = noise
@@ -1154,8 +1154,8 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
-                        init_latents_proper, torch.tensor([noise_timestep]), noise
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
                     )
 
                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 1 addition & 1 deletion
@@ -562,7 +562,7 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        latents = self.scheduler.add_noise(image_latents, noise, timestep)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

Lines changed: 3 additions & 3 deletions
@@ -582,7 +582,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             noise = latents.to(device)
             latents = noise
@@ -978,8 +978,8 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
-                        init_latents_proper, torch.tensor([noise_timestep]), noise
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
                     )
 
                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 1 addition & 1 deletion
@@ -671,7 +671,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
         # get latents
-        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents.to(device=device, dtype=dtype)
 
         return latents

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 3 additions & 3 deletions
@@ -680,7 +680,7 @@ def prepare_latents(
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             noise = latents.to(device)
             latents = noise
@@ -1145,8 +1145,8 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
-                        init_latents_proper, torch.tensor([noise_timestep]), noise
+                    init_latents_proper = self.scheduler.add_noise(
+                        init_latents_proper, noise, torch.tensor([noise_timestep])
                     )
 
                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents
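
Both inpainting pipelines repeat the same loop-body pattern, now expressed through `add_noise`: the clean image latents are re-noised to the next timestep's noise level and blended back outside the mask. A condensed sketch of that pattern is below; the helper name and argument list are illustrative, not part of the pipelines.

    import torch

    def renoise_and_blend(scheduler, latents, init_latents, init_mask, noise, timesteps, i):
        # re-noise the clean image latents to the *next* step's level (skipped at the last step)
        init_latents_proper = init_latents
        if i < len(timesteps) - 1:
            noise_timestep = timesteps[i + 1]
            init_latents_proper = scheduler.add_noise(
                init_latents_proper, noise, torch.tensor([noise_timestep])
            )
        # keep the denoised prediction where the mask is 1, the re-noised image elsewhere
        return (1 - init_mask) * init_latents_proper + init_mask * latents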

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 115 additions & 42 deletions
@@ -196,13 +196,21 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_match: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
+        shift: float = 1.0,
+        use_dynamic_shifting=False,
+        base_shift: Optional[float] = 0.5,
+        max_shift: Optional[float] = 1.15,
+        base_image_seq_len: Optional[int] = 256,
+        max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -234,20 +242,39 @@ def __init__(
             # FP16 smallest positive subnormal works well here
             self.alphas_cumprod[-1] = 2**-24
 
-        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        if use_flow_match:
+            timestep_offset = 1
+        else:
+            timestep_offset = 0
+
+        timesteps = np.linspace(
+            0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
+        )[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
+        if use_flow_match:
+            sigmas = timesteps / num_train_timesteps
+            if not use_dynamic_shifting:
+                # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+                sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        else:
+            sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
+
         # setable values
         self.num_inference_steps = None
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if timestep_type == "continuous" and prediction_type == "v_prediction":
             self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+        elif use_flow_match:
+            self.timesteps = sigmas * num_train_timesteps
         else:
             self.timesteps = timesteps
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if not use_flow_match:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+        self.sigmas = sigmas
 
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
@@ -257,6 +284,8 @@ def __init__(
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
 
     @property
     def init_noise_sigma(self):
@@ -322,6 +351,7 @@ def set_timesteps(
         device: Union[str, torch.device] = None,
         timesteps: Optional[List[int]] = None,
         sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -362,57 +392,81 @@
             raise ValueError(
                 "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
             )
+        if timesteps is not None and self.config.use_flow_match:
+            # TODO: `timesteps / self.config.num_train_timesteps` to get sigmas?
+            raise ValueError("Cannot set `timesteps` with `config.use_flow_match = True`.")
+
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if num_inference_steps is None:
             num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
         self.num_inference_steps = num_inference_steps
 
-        if sigmas is not None:
+        if sigmas is not None and not self.config.use_flow_match:
             log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
             sigmas = np.array(sigmas).astype(np.float32)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
-
-        else:
+        elif sigmas is None:
             if timesteps is not None:
                 timesteps = np.array(timesteps).astype(np.float32)
             else:
-                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-                if self.config.timestep_spacing == "linspace":
+                if self.config.use_flow_match:
                     timesteps = np.linspace(
-                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
-                    )[::-1].copy()
-                elif self.config.timestep_spacing == "leading":
-                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-                    )
-                    timesteps += self.config.steps_offset
-                elif self.config.timestep_spacing == "trailing":
-                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                        self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
                     )
-                    timesteps -= 1
+                    sigmas = timesteps / self.config.num_train_timesteps
                 else:
-                    raise ValueError(
-                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-                    )
+                    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+                    if self.config.timestep_spacing == "linspace":
+                        timesteps = np.linspace(
+                            0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                        )[::-1].copy()
+                    elif self.config.timestep_spacing == "leading":
+                        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                        )
+                        timesteps += self.config.steps_offset
+                    elif self.config.timestep_spacing == "trailing":
+                        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
+                            .round()
+                            .copy()
+                            .astype(np.float32)
+                        )
+                        timesteps -= 1
+                    else:
+                        raise ValueError(
+                            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                        )
+                    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+                    if self.config.interpolation_type == "linear":
+                        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+                    elif self.config.interpolation_type == "log_linear":
+                        sigmas = (
+                            torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)
+                            .exp()
+                            .numpy()
+                        )
+                    else:
+                        raise ValueError(
+                            f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                            " 'linear' or 'log_linear'"
+                        )
+
+            if self.config.use_flow_match:
+                if self.config.use_dynamic_shifting:
+                    sigmas = self.time_shift(mu, 1.0, sigmas)
+                else:
+                    sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
-            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
             log_sigmas = np.log(sigmas)
-            if self.config.interpolation_type == "linear":
-                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            elif self.config.interpolation_type == "log_linear":
-                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
-            else:
-                raise ValueError(
-                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
-                    " 'linear' or 'log_linear'"
-                )
 
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
@@ -426,10 +480,16 @@ def set_timesteps(
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
             sigma_last = 0
+        elif self.config.invert_sigmas:
+            sigma_last = 1
         else:
             raise ValueError(
                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
@@ -442,14 +502,21 @@ def set_timesteps(
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
             self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
+        elif self.config.use_flow_match:
+            self.timesteps = sigmas * self.config.num_train_timesteps
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
 
         self._step_index = None
         self._begin_index = None
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _sigma_to_t(self, sigma, log_sigmas=None):
+        if self.config.use_flow_match:
+            return sigma * self.config.num_train_timesteps
         # get log sigma
         log_sigma = np.log(np.maximum(sigma, 1e-10))
 
@@ -622,7 +689,7 @@ def step(
                 ),
             )
 
-        if not self.is_scale_input_called:
+        if not self.is_scale_input_called and not self.config.use_flow_match:
             logger.warning(
                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                 "See `StableDiffusionPipeline` for a usage example."
@@ -663,7 +730,10 @@ def step(
             )
 
         # 2. Convert to an ODE derivative
-        derivative = (sample - pred_original_sample) / sigma_hat
+        if self.config.use_flow_match:
+            derivative = model_output
+        else:
+            derivative = (sample - pred_original_sample) / sigma_hat
 
         dt = self.sigmas[self.step_index + 1] - sigma_hat
 
@@ -713,7 +783,10 @@ def add_noise(
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
-        noisy_samples = original_samples + noise * sigma
+        if self.config.use_flow_match:
+            noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
+        else:
+            noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
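
For reference, the flow-match math that this file now carries behind `use_flow_match` fits in a few lines. The sketch below is standalone and illustrative, not library code: it mirrors the formulas visible in the diff (the sigma schedule, the static and dynamic shift, the `add_noise` interpolation, and the Euler step in which the model output is already the ODE derivative). The values of `shift` and `mu` and the tensor shapes are assumptions made up for the example.

    import math
    import torch

    num_train_timesteps = 1000
    num_steps = 28
    shift = 3.0  # illustrative static shift

    # flow-match sigma schedule: sigmas = t / N on a descending timestep grid
    timesteps = torch.linspace(num_train_timesteps, 1, num_steps)
    base = timesteps / num_train_timesteps
    sigmas = shift * base / (1 + (shift - 1) * base)  # static shift (use_dynamic_shifting=False)

    # dynamic, resolution-dependent shift (use_dynamic_shifting=True), cf. time_shift(mu, 1.0, t)
    mu = 0.7  # illustrative; pipelines derive mu from the image sequence length
    sigmas_dynamic = math.exp(mu) / (math.exp(mu) + (1.0 / base - 1.0) ** 1.0)

    # add_noise() with use_flow_match=True: x_t = (1 - sigma) * x0 + sigma * noise
    x0 = torch.randn(1, 16, 32, 32)
    noise = torch.randn_like(x0)
    x_t = (1.0 - sigmas[0]) * x0 + sigmas[0] * noise

    # one Euler step: with flow matching the model output is already the derivative,
    # so x_next = x_t + (sigma_next - sigma) * model_output
    model_output = torch.randn_like(x_t)  # stand-in for the transformer prediction
    dt = sigmas[1] - sigmas[0]
    x_next = x_t + dt * model_output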
