1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import math
16- from dataclasses import dataclass
1715from typing import List , Optional , Tuple , Union
1816
1917import numpy as np
2018import torch
2119
22- from ..configuration_utils import ConfigMixin , register_to_config
23- from ..utils import BaseOutput , logging
20+ from ..utils import logging
2421from ..utils .torch_utils import randn_tensor
25- from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin
22+ from .scheduling_euler_ancestral_discrete import (
23+ EulerAncestralDiscreteScheduler ,
24+ EulerAncestralDiscreteSchedulerOutput ,
25+ rescale_zero_terminal_snr ,
26+ )
2627
2728
2829logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
2930
3031
31- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
32- @dataclass
33- class EulerAncestralDiscreteXPredSchedulerOutput (BaseOutput ):
32+ class EulerAncestralDiscreteXPredScheduler (EulerAncestralDiscreteScheduler ):
3433 """
35- Output class for the scheduler's `step` function output.
36-
37- Args:
38- prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40- denoising loop.
41- pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43- `pred_original_sample` can be used to preview progress or for guidance.
44- """
45-
46- prev_sample : torch .FloatTensor
47- pred_original_sample : Optional [torch .FloatTensor ] = None
48-
49-
50- # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
51- def rescale_zero_terminal_snr (betas ):
52- """
53- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
54-
55-
56- Args:
57- betas (`torch.Tensor`):
58- the betas that the scheduler is being initialized with.
59-
60- Returns:
61- `torch.Tensor`: rescaled betas with zero terminal SNR
62- """
63- # Convert betas to alphas_bar_sqrt
64- alphas = 1.0 - betas
65- alphas_cumprod = torch .cumprod (alphas , dim = 0 )
66- alphas_bar_sqrt = alphas_cumprod .sqrt ()
67-
68- # Store old values.
69- alphas_bar_sqrt_0 = alphas_bar_sqrt [0 ].clone ()
70- alphas_bar_sqrt_T = alphas_bar_sqrt [- 1 ].clone ()
71-
72- # Shift so the last timestep is zero.
73- alphas_bar_sqrt -= alphas_bar_sqrt_T
74-
75- # Scale so the first timestep is back to the old value.
76- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T )
77-
78- # Convert alphas_bar_sqrt to betas
79- alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
80- alphas = alphas_bar [1 :] / alphas_bar [:- 1 ] # Revert cumprod
81- alphas = torch .cat ([alphas_bar [0 :1 ], alphas ])
82- betas = 1 - alphas
83-
84- return betas
85-
86-
87- # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
88- def betas_for_alpha_bar (
89- num_diffusion_timesteps ,
90- max_beta = 0.999 ,
91- alpha_transform_type = "cosine" ,
92- ):
93- """
94- Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
95- (1-beta) over time from t = [0,1].
96-
97- Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
98- to that part of the diffusion process.
99-
100-
101- Args:
102- num_diffusion_timesteps (`int`): the number of betas to produce.
103- max_beta (`float`): the maximum beta to use; use values lower than 1 to
104- prevent singularities.
105- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
106- Choose from `cosine` or `exp`
107-
108- Returns:
109- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
110- """
111- if alpha_transform_type == "cosine" :
112-
113- def alpha_bar_fn (t ):
114- return math .cos ((t + 0.008 ) / 1.008 * math .pi / 2 ) ** 2
115-
116- elif alpha_transform_type == "exp" :
117-
118- def alpha_bar_fn (t ):
119- return math .exp (t * - 12.0 )
120-
121- else :
122- raise ValueError (f"Unsupported alpha_tranform_type: { alpha_transform_type } " )
123-
124- betas = []
125- for i in range (num_diffusion_timesteps ):
126- t1 = i / num_diffusion_timesteps
127- t2 = (i + 1 ) / num_diffusion_timesteps
128- betas .append (min (1 - alpha_bar_fn (t2 ) / alpha_bar_fn (t1 ), max_beta ))
129- return torch .tensor (betas , dtype = torch .float32 )
130-
131-
132- class EulerAncestralDiscreteXPredScheduler (SchedulerMixin , ConfigMixin ):
133- """
134- Ancestral sampling with Euler method steps.
135-
136- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
137- methods the library implements for all schedulers such as loading and saving.
34+ Ancestral sampling with Euler method steps. This model inherits from [`EulerAncestralDiscreteScheduler`]. Check the superclass
35+ documentation for the args and returns.
13836
13937 For more details, see the original paper: https://arxiv.org/abs/2403.08381
140-
141- Args:
142- num_train_timesteps (`int`, defaults to 1000):
143- The number of diffusion steps to train the model.
144- beta_start (`float`, defaults to 0.0001):
145- The starting `beta` value of inference.
146- beta_end (`float`, defaults to 0.02):
147- The final `beta` value.
148- beta_schedule (`str`, defaults to `"linear"`):
149- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
150- `linear` or `scaled_linear`.
151- trained_betas (`np.ndarray`, *optional*):
152- Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
153- prediction_type (`str`, defaults to `epsilon`, *optional*):
154- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
155- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
156- Video](https://imagen.research.google/video/paper.pdf) paper).
157- timestep_spacing (`str`, defaults to `"linspace"`):
158- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
159- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
160- steps_offset (`int`, defaults to 0):
161- An offset added to the inference steps, as required by some model families.
16238 """
16339
164- _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
165- order = 1
166-
167- @register_to_config
16840 def __init__ (
16941 self ,
17042 num_train_timesteps : int = 1000 ,
@@ -175,35 +47,21 @@ def __init__(
17547 prediction_type : str = "epsilon" ,
17648 timestep_spacing : str = "linspace" ,
17749 steps_offset : int = 0
178- ):
179- if trained_betas is not None :
180- self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
181- elif beta_schedule == "linear" :
182- self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
183- elif beta_schedule == "scaled_linear" :
184- # this schedule is very specific to the latent diffusion model.
185- self .betas = (
186- torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
50+ ):
51+ super (EulerAncestralDiscreteXPredScheduler , self ).__init__ (
52+ num_train_timesteps ,
53+ beta_start ,
54+ beta_end ,
55+ beta_schedule ,
56+ trained_betas ,
57+ prediction_type ,
58+ timestep_spacing ,
59+ steps_offset
18760 )
188- elif beta_schedule == "squaredcos_cap_v2" :
189- # Glide cosine schedule
190- self .betas = betas_for_alpha_bar (num_train_timesteps )
191- else :
192- raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
193-
194-
195- self .alphas = 1.0 - self .betas
196- self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
19761
19862 sigmas = np .array (((1 - self .alphas_cumprod )) ** 0.5 , dtype = np .float32 )
19963 self .sigmas = torch .from_numpy (sigmas )
20064
201- # setable values
202- self .num_inference_steps = None
203- timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
204- self .timesteps = torch .from_numpy (timesteps )
205- self .is_scale_input_called = False
206-
20765 def rescale_betas_zero_snr (self ):
20866 self .betas = rescale_zero_terminal_snr (self .betas )
20967 self .alphas = 1.0 - self .betas
@@ -274,7 +132,7 @@ def step(
274132 sample : torch .FloatTensor ,
275133 generator : Optional [torch .Generator ] = None ,
276134 return_dict : bool = True ,
277- ) -> Union [EulerAncestralDiscreteXPredSchedulerOutput , Tuple ]:
135+ ) -> Union [EulerAncestralDiscreteSchedulerOutput , Tuple ]:
278136 """
279137 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
280138 process from the learned model outputs (most often the predicted noise).
@@ -285,11 +143,11 @@ def step(
285143 sample (`torch.FloatTensor`):
286144 current instance of sample being created by diffusion process.
287145 generator (`torch.Generator`, optional): Random number generator.
288- return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteXPredSchedulerOutput class
146+ return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
289147
290148 Returns:
291- [`~schedulers.scheduling_utils.EulerAncestralDiscreteXPredSchedulerOutput `] or `tuple`:
292- [`~schedulers.scheduling_utils.EulerAncestralDiscreteXPredSchedulerOutput `] if `return_dict` is True, otherwise
149+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput `] or `tuple`:
150+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput `] if `return_dict` is True, otherwise
293151 a `tuple`. When returning a tuple, the first element is the sample tensor.
294152
295153 """
@@ -312,12 +170,7 @@ def step(
312170
313171 step_index = (self .timesteps == timestep ).nonzero ().item ()
314172
315- if self .config .prediction_type == "epsilon" :
316- pred_original_sample = sample - sigma * model_output
317- elif self .config .prediction_type == "v_prediction" :
318- # * c_out + input * c_skip
319- pred_original_sample = model_output * (- sigma / (sigma ** 2 + 1 ) ** 0.5 ) + (sample / (sigma ** 2 + 1 ))
320- elif self .config .prediction_type == "sample" :
173+ if self .config .prediction_type == "sample" :
321174 pred_original_sample = model_output
322175 else :
323176 raise ValueError (
@@ -340,11 +193,10 @@ def step(
340193 if not return_dict :
341194 return (prev_sample ,)
342195
343- return EulerAncestralDiscreteXPredSchedulerOutput (
196+ return EulerAncestralDiscreteSchedulerOutput (
344197 prev_sample = prev_sample , pred_original_sample = pred_original_sample
345198 )
346199
347- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
348200 def add_noise (
349201 self ,
350202 original_samples : torch .FloatTensor ,
@@ -369,6 +221,3 @@ def add_noise(
369221
370222 noisy_samples = original_samples + noise * sigma
371223 return noisy_samples
372-
373- def __len__ (self ):
374- return self .config .num_train_timesteps
0 commit comments