Skip to content

Commit a4429e0

Browse files
authored
Merge branch 'main' into dreambooth-lora-flux-exploration
2 parents f62af61 + 99d8747 commit a4429e0

20 files changed

+205
-75
lines changed

benchmarks/push_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas as pd
55
from huggingface_hub import hf_hub_download, upload_file
6-
from huggingface_hub.utils._errors import EntryNotFoundError
6+
from huggingface_hub.utils import EntryNotFoundError
77

88

99
sys.path.append(".")

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def __init__(
225225
requires_safety_checker: bool = True,
226226
):
227227
super().__init__()
228+
if isinstance(controlnet, (list, tuple)):
229+
controlnet = HunyuanDiT2DMultiControlNetModel(controlnet)
228230

229231
self.register_modules(
230232
vae=vae,

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def __init__(
192192
],
193193
):
194194
super().__init__()
195+
if isinstance(controlnet, (list, tuple)):
196+
controlnet = SD3MultiControlNetModel(controlnet)
195197

196198
self.register_modules(
197199
vae=vae,

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,10 @@ def step(
463463
prev_sample = prev_sample + variance
464464

465465
if not return_dict:
466-
return (prev_sample,)
466+
return (
467+
prev_sample,
468+
pred_original_sample,
469+
)
467470

468471
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
469472

src/diffusers/schedulers/scheduling_ddim_cogvideox.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ def step(
394394
prev_sample = a_t * sample + b_t * pred_original_sample
395395

396396
if not return_dict:
397-
return (prev_sample,)
397+
return (
398+
prev_sample,
399+
pred_original_sample,
400+
)
398401

399402
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
400403

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,10 @@ def step(
480480
prev_sample = prev_sample + variance
481481

482482
if not return_dict:
483-
return (prev_sample,)
483+
return (
484+
prev_sample,
485+
pred_original_sample,
486+
)
484487

485488
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
486489

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,10 @@ def step(
492492
pred_prev_sample = pred_prev_sample + variance
493493

494494
if not return_dict:
495-
return (pred_prev_sample,)
495+
return (
496+
pred_prev_sample,
497+
pred_original_sample,
498+
)
496499

497500
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
498501

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,10 @@ def step(
500500
pred_prev_sample = pred_prev_sample + variance
501501

502502
if not return_dict:
503-
return (pred_prev_sample,)
503+
return (
504+
pred_prev_sample,
505+
pred_original_sample,
506+
)
504507

505508
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
506509

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,41 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021
import torchsde
2122

2223
from ..configuration_utils import ConfigMixin, register_to_config
23-
from ..utils import is_scipy_available
24-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
24+
from ..utils import BaseOutput, is_scipy_available
25+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2526

2627

2728
if is_scipy_available():
2829
import scipy.stats
2930

3031

32+
@dataclass
33+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
34+
class DPMSolverSDESchedulerOutput(BaseOutput):
35+
"""
36+
Output class for the scheduler's `step` function output.
37+
38+
Args:
39+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
40+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41+
denoising loop.
42+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
43+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44+
`pred_original_sample` can be used to preview progress or for guidance.
45+
"""
46+
47+
prev_sample: torch.Tensor
48+
pred_original_sample: Optional[torch.Tensor] = None
49+
50+
3151
class BatchedBrownianTree:
3252
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
3353

@@ -510,7 +530,7 @@ def step(
510530
sample: Union[torch.Tensor, np.ndarray],
511531
return_dict: bool = True,
512532
s_noise: float = 1.0,
513-
) -> Union[SchedulerOutput, Tuple]:
533+
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
514534
"""
515535
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
516536
process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ def step(
522542
The current discrete timestep in the diffusion chain.
523543
sample (`torch.Tensor` or `np.ndarray`):
524544
A current instance of a sample created by the diffusion process.
525-
return_dict (`bool`, *optional*, defaults to `True`):
526-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
545+
return_dict (`bool`):
546+
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
547+
tuple.
527548
s_noise (`float`, *optional*, defaults to 1.0):
528549
Scaling factor for noise added to the sample.
529550
530551
Returns:
531-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
532-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
533-
tuple is returned where the first element is the sample tensor.
552+
[`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
553+
If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
554+
returned, otherwise a tuple is returned where the first element is the sample tensor.
534555
"""
535556
if self.step_index is None:
536557
self._init_step_index(timestep)
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
610631
self._step_index += 1
611632

612633
if not return_dict:
613-
return (prev_sample,)
634+
return (
635+
prev_sample,
636+
pred_original_sample,
637+
)
614638

615-
return SchedulerOutput(prev_sample=prev_sample)
639+
return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
616640

617641
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
618642
def add_noise(

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,13 @@ def step(
333333

334334
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
335335

336-
noise = randn_tensor(
337-
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
338-
)
339-
340-
eps = noise * s_noise
341336
sigma_hat = sigma * (gamma + 1)
342337

343338
if gamma > 0:
339+
noise = randn_tensor(
340+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
341+
)
342+
eps = noise * s_noise
344343
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
345344

346345
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -360,7 +359,10 @@ def step(
360359
self._step_index += 1
361360

362361
if not return_dict:
363-
return (prev_sample,)
362+
return (
363+
prev_sample,
364+
pred_original_sample,
365+
)
364366

365367
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
366368

0 commit comments

Comments
 (0)