Skip to content

Commit 8f759c2

Browse files
committed
Refactor SchedulerOutput and add pred_original_sample
1 parent 164ec9f commit 8f759c2

File tree

4 files changed

+131
-35
lines changed

4 files changed

+131
-35
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,41 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021
import torchsde
2122

2223
from ..configuration_utils import ConfigMixin, register_to_config
23-
from ..utils import is_scipy_available
24-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
24+
from ..utils import BaseOutput, is_scipy_available
25+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2526

2627

2728
if is_scipy_available():
2829
import scipy.stats
2930

3031

32+
@dataclass
33+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
34+
class DPMSolverSDESchedulerOutput(BaseOutput):
35+
"""
36+
Output class for the scheduler's `step` function output.
37+
38+
Args:
39+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
40+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41+
denoising loop.
42+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
43+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44+
`pred_original_sample` can be used to preview progress or for guidance.
45+
"""
46+
47+
prev_sample: torch.Tensor
48+
pred_original_sample: Optional[torch.Tensor] = None
49+
50+
3151
class BatchedBrownianTree:
3252
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
3353

@@ -510,7 +530,7 @@ def step(
510530
sample: Union[torch.Tensor, np.ndarray],
511531
return_dict: bool = True,
512532
s_noise: float = 1.0,
513-
) -> Union[SchedulerOutput, Tuple]:
533+
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
514534
"""
515535
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
516536
process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ def step(
522542
The current discrete timestep in the diffusion chain.
523543
sample (`torch.Tensor` or `np.ndarray`):
524544
A current instance of a sample created by the diffusion process.
525-
return_dict (`bool`, *optional*, defaults to `True`):
526-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
545+
return_dict (`bool`):
546+
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
547+
tuple.
527548
s_noise (`float`, *optional*, defaults to 1.0):
528549
Scaling factor for noise added to the sample.
529550
530551
Returns:
531-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
532-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
533-
tuple is returned where the first element is the sample tensor.
552+
[`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
553+
If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
554+
returned, otherwise a tuple is returned where the first element is the sample tensor.
534555
"""
535556
if self.step_index is None:
536557
self._init_step_index(timestep)
@@ -610,9 +631,12 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
610631
self._step_index += 1
611632

612633
if not return_dict:
613-
return (prev_sample,)
634+
return (
635+
prev_sample,
636+
pred_original_sample,
637+
)
614638

615-
return SchedulerOutput(prev_sample=prev_sample)
639+
return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
616640

617641
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
618642
def add_noise(

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,40 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021

2122
from ..configuration_utils import ConfigMixin, register_to_config
22-
from ..utils import is_scipy_available
23-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
23+
from ..utils import BaseOutput, is_scipy_available
24+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2425

2526

2627
if is_scipy_available():
2728
import scipy.stats
2829

2930

31+
@dataclass
32+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
33+
class HeunDiscreteSchedulerOutput(BaseOutput):
34+
"""
35+
Output class for the scheduler's `step` function output.
36+
37+
Args:
38+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40+
denoising loop.
41+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43+
`pred_original_sample` can be used to preview progress or for guidance.
44+
"""
45+
46+
prev_sample: torch.Tensor
47+
pred_original_sample: Optional[torch.Tensor] = None
48+
49+
3050
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
3151
def betas_for_alpha_bar(
3252
num_diffusion_timesteps,
@@ -455,7 +475,7 @@ def step(
455475
timestep: Union[float, torch.Tensor],
456476
sample: Union[torch.Tensor, np.ndarray],
457477
return_dict: bool = True,
458-
) -> Union[SchedulerOutput, Tuple]:
478+
) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
459479
"""
460480
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
461481
process from the learned model outputs (most often the predicted noise).
@@ -468,12 +488,13 @@ def step(
468488
sample (`torch.Tensor`):
469489
A current instance of a sample created by the diffusion process.
470490
return_dict (`bool`):
471-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
491+
Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
492+
tuple.
472493
473494
Returns:
474-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
475-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
476-
tuple is returned where the first element is the sample tensor.
495+
[`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
496+
If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
497+
returned, otherwise a tuple is returned where the first element is the sample tensor.
477498
"""
478499
if self.step_index is None:
479500
self._init_step_index(timestep)
@@ -544,9 +565,12 @@ def step(
544565
self._step_index += 1
545566

546567
if not return_dict:
547-
return (prev_sample,)
568+
return (
569+
prev_sample,
570+
pred_original_sample,
571+
)
548572

549-
return SchedulerOutput(prev_sample=prev_sample)
573+
return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
550574

551575
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
552576
def add_noise(

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,41 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021

2122
from ..configuration_utils import ConfigMixin, register_to_config
22-
from ..utils import is_scipy_available
23+
from ..utils import BaseOutput, is_scipy_available
2324
from ..utils.torch_utils import randn_tensor
24-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
25+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2526

2627

2728
if is_scipy_available():
2829
import scipy.stats
2930

3031

32+
@dataclass
33+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
34+
class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
35+
"""
36+
Output class for the scheduler's `step` function output.
37+
38+
Args:
39+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
40+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41+
denoising loop.
42+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
43+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44+
`pred_original_sample` can be used to preview progress or for guidance.
45+
"""
46+
47+
prev_sample: torch.Tensor
48+
pred_original_sample: Optional[torch.Tensor] = None
49+
50+
3151
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
3252
def betas_for_alpha_bar(
3353
num_diffusion_timesteps,
@@ -459,7 +479,7 @@ def step(
459479
sample: Union[torch.Tensor, np.ndarray],
460480
generator: Optional[torch.Generator] = None,
461481
return_dict: bool = True,
462-
) -> Union[SchedulerOutput, Tuple]:
482+
) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
463483
"""
464484
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
465485
process from the learned model outputs (most often the predicted noise).
@@ -474,11 +494,11 @@ def step(
474494
generator (`torch.Generator`, *optional*):
475495
A random number generator.
476496
return_dict (`bool`):
477-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
497+
Whether or not to return a [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.
478498
479499
Returns:
480-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
481-
If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
500+
[`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
501+
If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is returned, otherwise a
482502
tuple is returned where the first element is the sample tensor.
483503
"""
484504
if self.step_index is None:
@@ -548,9 +568,14 @@ def step(
548568
self._step_index += 1
549569

550570
if not return_dict:
551-
return (prev_sample,)
571+
return (
572+
prev_sample,
573+
pred_original_sample,
574+
)
552575

553-
return SchedulerOutput(prev_sample=prev_sample)
576+
return KDPM2AncestralDiscreteSchedulerOutput(
577+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
578+
)
554579

555580
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
556581
def add_noise(

src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,40 @@
1313
# limitations under the License.
1414

1515
import math
16+
from dataclasses import dataclass
1617
from typing import List, Optional, Tuple, Union
1718

1819
import numpy as np
1920
import torch
2021

2122
from ..configuration_utils import ConfigMixin, register_to_config
22-
from ..utils import is_scipy_available
23-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
23+
from ..utils import BaseOutput, is_scipy_available
24+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
2425

2526

2627
if is_scipy_available():
2728
import scipy.stats
2829

2930

31+
@dataclass
32+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
33+
class KDPM2DiscreteSchedulerOutput(BaseOutput):
34+
"""
35+
Output class for the scheduler's `step` function output.
36+
37+
Args:
38+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40+
denoising loop.
41+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43+
`pred_original_sample` can be used to preview progress or for guidance.
44+
"""
45+
46+
prev_sample: torch.Tensor
47+
pred_original_sample: Optional[torch.Tensor] = None
48+
49+
3050
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
3151
def betas_for_alpha_bar(
3252
num_diffusion_timesteps,
@@ -443,7 +463,7 @@ def step(
443463
timestep: Union[float, torch.Tensor],
444464
sample: Union[torch.Tensor, np.ndarray],
445465
return_dict: bool = True,
446-
) -> Union[SchedulerOutput, Tuple]:
466+
) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
447467
"""
448468
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
449469
process from the learned model outputs (most often the predicted noise).
@@ -456,11 +476,11 @@ def step(
456476
sample (`torch.Tensor`):
457477
A current instance of a sample created by the diffusion process.
458478
return_dict (`bool`):
459-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
479+
Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or tuple.
460480
461481
Returns:
462-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
463-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
482+
[`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
483+
If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is returned, otherwise a
464484
tuple is returned where the first element is the sample tensor.
465485
"""
466486
if self.step_index is None:
@@ -523,9 +543,12 @@ def step(
523543
prev_sample = sample + derivative * dt
524544

525545
if not return_dict:
526-
return (prev_sample,)
546+
return (
547+
prev_sample,
548+
pred_original_sample,
549+
)
527550

528-
return SchedulerOutput(prev_sample=prev_sample)
551+
return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
529552

530553
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
531554
def add_noise(

0 commit comments

Comments
 (0)