Skip to content

Commit 2c09f65

Browse files
authored
Merge branch 'main' into main
2 parents 75a7d85 + 5f0df17 commit 2c09f65

16 files changed

+180
-47
lines changed

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def __init__(
225225
requires_safety_checker: bool = True,
226226
):
227227
super().__init__()
228+
if isinstance(controlnet, (list, tuple)):
229+
controlnet = HunyuanDiT2DMultiControlNetModel(controlnet)
228230

229231
self.register_modules(
230232
vae=vae,

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def __init__(
192192
],
193193
):
194194
super().__init__()
195+
if isinstance(controlnet, (list, tuple)):
196+
controlnet = SD3MultiControlNetModel(controlnet)
195197

196198
self.register_modules(
197199
vae=vae,

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,10 @@ def step(
463463
prev_sample = prev_sample + variance
464464

465465
if not return_dict:
466-
return (prev_sample,)
466+
return (
467+
prev_sample,
468+
pred_original_sample,
469+
)
467470

468471
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
469472

src/diffusers/schedulers/scheduling_ddim_cogvideox.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ def step(
394394
prev_sample = a_t * sample + b_t * pred_original_sample
395395

396396
if not return_dict:
397-
return (prev_sample,)
397+
return (
398+
prev_sample,
399+
pred_original_sample,
400+
)
398401

399402
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
400403

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,10 @@ def step(
480480
prev_sample = prev_sample + variance
481481

482482
if not return_dict:
483-
return (prev_sample,)
483+
return (
484+
prev_sample,
485+
pred_original_sample,
486+
)
484487

485488
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
486489

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,10 @@ def step(
492492
pred_prev_sample = pred_prev_sample + variance
493493

494494
if not return_dict:
495-
return (pred_prev_sample,)
495+
return (
496+
pred_prev_sample,
497+
pred_original_sample,
498+
)
496499

497500
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
498501

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,10 @@ def step(
500500
pred_prev_sample = pred_prev_sample + variance
501501

502502
if not return_dict:
503-
return (pred_prev_sample,)
503+
return (
504+
pred_prev_sample,
505+
pred_original_sample,
506+
)
504507

505508
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
506509

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,41 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021
import torchsde
2122

2223
from ..configuration_utils import ConfigMixin, register_to_config
23-
from ..utils import is_scipy_available
24-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
24+
from ..utils import BaseOutput, is_scipy_available
25+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2526

2627

2728
if is_scipy_available():
2829
import scipy.stats
2930

3031

32+
@dataclass
33+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
34+
class DPMSolverSDESchedulerOutput(BaseOutput):
35+
"""
36+
Output class for the scheduler's `step` function output.
37+
38+
Args:
39+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
40+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41+
denoising loop.
42+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
43+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44+
`pred_original_sample` can be used to preview progress or for guidance.
45+
"""
46+
47+
prev_sample: torch.Tensor
48+
pred_original_sample: Optional[torch.Tensor] = None
49+
50+
3151
class BatchedBrownianTree:
3252
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
3353

@@ -510,7 +530,7 @@ def step(
510530
sample: Union[torch.Tensor, np.ndarray],
511531
return_dict: bool = True,
512532
s_noise: float = 1.0,
513-
) -> Union[SchedulerOutput, Tuple]:
533+
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
514534
"""
515535
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
516536
process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ def step(
522542
The current discrete timestep in the diffusion chain.
523543
sample (`torch.Tensor` or `np.ndarray`):
524544
A current instance of a sample created by the diffusion process.
525-
return_dict (`bool`, *optional*, defaults to `True`):
526-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
545+
return_dict (`bool`):
546+
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
547+
tuple.
527548
s_noise (`float`, *optional*, defaults to 1.0):
528549
Scaling factor for noise added to the sample.
529550
530551
Returns:
531-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
532-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
533-
tuple is returned where the first element is the sample tensor.
552+
[`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
553+
If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
554+
returned, otherwise a tuple is returned where the first element is the sample tensor.
534555
"""
535556
if self.step_index is None:
536557
self._init_step_index(timestep)
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
610631
self._step_index += 1
611632

612633
if not return_dict:
613-
return (prev_sample,)
634+
return (
635+
prev_sample,
636+
pred_original_sample,
637+
)
614638

615-
return SchedulerOutput(prev_sample=prev_sample)
639+
return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
616640

617641
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
618642
def add_noise(

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,10 @@ def step(
360360
self._step_index += 1
361361

362362
if not return_dict:
363-
return (prev_sample,)
363+
return (
364+
prev_sample,
365+
pred_original_sample,
366+
)
364367

365368
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
366369

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,10 @@ def step(
435435
self._step_index += 1
436436

437437
if not return_dict:
438-
return (prev_sample,)
438+
return (
439+
prev_sample,
440+
pred_original_sample,
441+
)
439442

440443
return EulerAncestralDiscreteSchedulerOutput(
441444
prev_sample=prev_sample, pred_original_sample=pred_original_sample

0 commit comments

Comments
 (0)