Skip to content

Commit de5fe50

Browse files
committed
EulerDiscreteScheduler
1 parent 488fb7b commit de5fe50

File tree

1 file changed

+7
-146
lines changed

1 file changed

+7
-146
lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 7 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import ConfigMixin, register_to_config
2121
from ..utils import BaseOutput, logging
2222
from ..utils.torch_utils import randn_tensor
23-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
23+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin
2424

2525

2626
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -45,7 +45,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
4545
pred_original_sample: Optional[torch.Tensor] = None
4646

4747

48-
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
48+
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin, SamplingMixin):
4949
"""
5050
Euler scheduler.
5151
@@ -106,13 +106,6 @@ def __init__(
106106
self.set_schedule(schedule_config)
107107
self.set_sigma_schedule(sigma_schedule_config)
108108

109-
# setable values
110-
self.num_inference_steps = None
111-
112-
self.is_scale_input_called = False
113-
self._step_index = None
114-
self._begin_index = None
115-
116109
@property
117110
def init_noise_sigma(self):
118111
# standard deviation of the initial noise distribution
@@ -122,31 +115,6 @@ def init_noise_sigma(self):
122115

123116
return (max_sigma**2 + 1) ** 0.5
124117

125-
@property
126-
def step_index(self):
127-
"""
128-
The index counter for current timestep. It will increase 1 after each scheduler step.
129-
"""
130-
return self._step_index
131-
132-
@property
133-
def begin_index(self):
134-
"""
135-
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
136-
"""
137-
return self._begin_index
138-
139-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
140-
def set_begin_index(self, begin_index: int = 0):
141-
"""
142-
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
143-
144-
Args:
145-
begin_index (`int`):
146-
The begin index for the scheduler.
147-
"""
148-
self._begin_index = begin_index
149-
150118
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
151119
"""
152120
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -188,15 +156,6 @@ def set_timesteps(
188156
The number of diffusion steps used when generating samples with a pre-trained model.
189157
device (`str` or `torch.device`, *optional*):
190158
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
191-
timesteps (`List[int]`, *optional*):
192-
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
193-
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
194-
must be `None`, and `timestep_spacing` attribute will be ignored.
195-
sigmas (`List[float]`, *optional*):
196-
Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
197-
will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
198-
`num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
199-
custom sigmas schedule.
200159
"""
201160

202161
if timesteps is not None and sigmas is not None:
@@ -210,7 +169,7 @@ def set_timesteps(
210169
and self._sigma_schedule is not None
211170
and self._sigma_schedule.__class__.__name__ == "KarrasSigmas"
212171
):
213-
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
172+
raise ValueError("Cannot set `timesteps` with `KarrasSigmas`.")
214173
if (
215174
timesteps is not None
216175
and self._sigma_schedule is not None
@@ -225,11 +184,11 @@ def set_timesteps(
225184
raise ValueError("Cannot set `timesteps` with `BetaSigmas`.")
226185
if (
227186
timesteps is not None
228-
and self._schedule.config.get("timestep_type", None) == "continuous"
187+
and self._schedule.timestep_type == "continuous"
229188
and self.config.prediction_type == "v_prediction"
230189
):
231190
raise ValueError(
232-
"Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
191+
"Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
233192
)
234193

235194
if num_inference_steps is None:
@@ -248,30 +207,8 @@ def set_timesteps(
248207

249208
self._step_index = None
250209
self._begin_index = None
251-
self.timesteps = timesteps
252-
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
253-
254-
def index_for_timestep(self, timestep, schedule_timesteps=None):
255-
if schedule_timesteps is None:
256-
schedule_timesteps = self.timesteps
257-
258-
indices = (schedule_timesteps == timestep).nonzero()
259-
260-
# The sigma index that is taken for the **very** first `step`
261-
# is always the second index (or the last index if there is only 1)
262-
# This way we can ensure we don't accidentally skip a sigma in
263-
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
264-
pos = 1 if len(indices) > 1 else 0
265-
266-
return indices[pos].item()
267-
268-
def _init_step_index(self, timestep):
269-
if self.begin_index is None:
270-
if isinstance(timestep, torch.Tensor):
271-
timestep = timestep.to(self.timesteps.device)
272-
self._step_index = self.index_for_timestep(timestep)
273-
else:
274-
self._step_index = self._begin_index
210+
self.timesteps = timesteps.to(device=device)
211+
self.sigmas = sigmas.to("cpu")
275212

276213
def step(
277214
self,
@@ -382,79 +319,3 @@ def step(
382319
)
383320

384321
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
385-
386-
def add_noise(
387-
self,
388-
original_samples: torch.Tensor,
389-
noise: torch.Tensor,
390-
timesteps: torch.Tensor,
391-
) -> torch.Tensor:
392-
# Make sure sigmas and timesteps have the same device and dtype as original_samples
393-
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
394-
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
395-
# mps does not support float64
396-
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
397-
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
398-
else:
399-
schedule_timesteps = self.timesteps.to(original_samples.device)
400-
timesteps = timesteps.to(original_samples.device)
401-
402-
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
403-
if self.begin_index is None:
404-
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
405-
elif self.step_index is not None:
406-
# add_noise is called after first denoising step (for inpainting)
407-
step_indices = [self.step_index] * timesteps.shape[0]
408-
else:
409-
# add noise is called before first denoising step to create initial latent(img2img)
410-
step_indices = [self.begin_index] * timesteps.shape[0]
411-
412-
sigma = sigmas[step_indices].flatten()
413-
while len(sigma.shape) < len(original_samples.shape):
414-
sigma = sigma.unsqueeze(-1)
415-
416-
if self._schedule.__class__.__name__ == "FlowMatchSchedule":
417-
noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
418-
else:
419-
noisy_samples = original_samples + noise * sigma
420-
return noisy_samples
421-
422-
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
423-
if (
424-
isinstance(timesteps, int)
425-
or isinstance(timesteps, torch.IntTensor)
426-
or isinstance(timesteps, torch.LongTensor)
427-
):
428-
raise ValueError(
429-
(
430-
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
431-
" `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
432-
" one of the `scheduler.timesteps` as a timestep."
433-
),
434-
)
435-
436-
if sample.device.type == "mps" and torch.is_floating_point(timesteps):
437-
# mps does not support float64
438-
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
439-
timesteps = timesteps.to(sample.device, dtype=torch.float32)
440-
else:
441-
schedule_timesteps = self.timesteps.to(sample.device)
442-
timesteps = timesteps.to(sample.device)
443-
444-
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
445-
alphas_cumprod = self.alphas_cumprod.to(sample)
446-
sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
447-
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
448-
while len(sqrt_alpha_prod.shape) < len(sample.shape):
449-
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
450-
451-
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
452-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
453-
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
454-
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
455-
456-
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
457-
return velocity
458-
459-
def __len__(self):
460-
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)