Skip to content

Commit 488fb7b

Browse files
committed
EulerAncestralDiscreteScheduler
1 parent 323806c commit 488fb7b

File tree

1 file changed

+13
-105
lines changed

1 file changed

+13
-105
lines changed

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 13 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..configuration_utils import ConfigMixin, register_to_config
2121
from ..utils import BaseOutput, logging
2222
from ..utils.torch_utils import randn_tensor
23-
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
23+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin
2424

2525

2626
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -45,7 +45,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
4545
pred_original_sample: Optional[torch.Tensor] = None
4646

4747

48-
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
48+
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin, SamplingMixin):
4949
"""
5050
Ancestral sampling with Euler method steps.
5151
@@ -92,46 +92,17 @@ def __init__(
9292
self.set_schedule(schedule_config)
9393
self.set_sigma_schedule(sigma_schedule_config)
9494

95-
# setable values
96-
self.num_inference_steps = None
97-
98-
self.is_scale_input_called = False
99-
self._step_index = None
100-
self._begin_index = None
101-
10295
@property
96+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.init_noise_sigma
10397
def init_noise_sigma(self):
10498
# standard deviation of the initial noise distribution
99+
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
105100
if self.config.timestep_spacing in ["linspace", "trailing"]:
106-
return self.sigmas.max()
107-
108-
return (self.sigmas.max() ** 2 + 1) ** 0.5
109-
110-
@property
111-
def step_index(self):
112-
"""
113-
The index counter for current timestep. It will increase 1 after each scheduler step.
114-
"""
115-
return self._step_index
101+
return max_sigma
116102

117-
@property
118-
def begin_index(self):
119-
"""
120-
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
121-
"""
122-
return self._begin_index
123-
124-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
125-
def set_begin_index(self, begin_index: int = 0):
126-
"""
127-
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
128-
129-
Args:
130-
begin_index (`int`):
131-
The begin index for the scheduler.
132-
"""
133-
self._begin_index = begin_index
103+
return (max_sigma**2 + 1) ** 0.5
134104

105+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input
135106
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
136107
"""
137108
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -147,18 +118,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
147118
`torch.Tensor`:
148119
A scaled input sample.
149120
"""
150-
151121
if self.step_index is None:
152122
self._init_step_index(timestep)
153123

154124
sigma = self.sigmas[self.step_index]
155125
sample = sample / ((sigma**2 + 1) ** 0.5)
126+
156127
self.is_scale_input_called = True
157128
return sample
158129

130+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.set_timesteps
159131
def set_timesteps(
160132
self,
161-
num_inference_steps: int,
133+
num_inference_steps: int = None,
162134
device: Union[str, torch.device] = None,
163135
timesteps: Optional[List[int]] = None,
164136
sigmas: Optional[List[float]] = None,
@@ -201,11 +173,11 @@ def set_timesteps(
201173
raise ValueError("Cannot set `timesteps` with `BetaSigmas`.")
202174
if (
203175
timesteps is not None
204-
and self._schedule.config.get("timestep_type", None) == "continuous"
176+
and self._schedule.timestep_type == "continuous"
205177
and self.config.prediction_type == "v_prediction"
206178
):
207179
raise ValueError(
208-
"Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
180+
"Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
209181
)
210182

211183
if num_inference_steps is None:
@@ -222,35 +194,11 @@ def set_timesteps(
222194
shift=shift,
223195
)
224196

225-
self.timesteps = timesteps.to(device=device)
226197
self._step_index = None
227198
self._begin_index = None
199+
self.timesteps = timesteps.to(device=device)
228200
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
229201

230-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
231-
def index_for_timestep(self, timestep, schedule_timesteps=None):
232-
if schedule_timesteps is None:
233-
schedule_timesteps = self.timesteps
234-
235-
indices = (schedule_timesteps == timestep).nonzero()
236-
237-
# The sigma index that is taken for the **very** first `step`
238-
# is always the second index (or the last index if there is only 1)
239-
# This way we can ensure we don't accidentally skip a sigma in
240-
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
241-
pos = 1 if len(indices) > 1 else 0
242-
243-
return indices[pos].item()
244-
245-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
246-
def _init_step_index(self, timestep):
247-
if self.begin_index is None:
248-
if isinstance(timestep, torch.Tensor):
249-
timestep = timestep.to(self.timesteps.device)
250-
self._step_index = self.index_for_timestep(timestep)
251-
else:
252-
self._step_index = self._begin_index
253-
254202
def step(
255203
self,
256204
model_output: torch.Tensor,
@@ -352,43 +300,3 @@ def step(
352300
return EulerAncestralDiscreteSchedulerOutput(
353301
prev_sample=prev_sample, pred_original_sample=pred_original_sample
354302
)
355-
356-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
357-
def add_noise(
358-
self,
359-
original_samples: torch.Tensor,
360-
noise: torch.Tensor,
361-
timesteps: torch.Tensor,
362-
) -> torch.Tensor:
363-
# Make sure sigmas and timesteps have the same device and dtype as original_samples
364-
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
365-
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
366-
# mps does not support float64
367-
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
368-
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
369-
else:
370-
schedule_timesteps = self.timesteps.to(original_samples.device)
371-
timesteps = timesteps.to(original_samples.device)
372-
373-
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
374-
if self.begin_index is None:
375-
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
376-
elif self.step_index is not None:
377-
# add_noise is called after first denoising step (for inpainting)
378-
step_indices = [self.step_index] * timesteps.shape[0]
379-
else:
380-
# add noise is called before first denoising step to create initial latent(img2img)
381-
step_indices = [self.begin_index] * timesteps.shape[0]
382-
383-
sigma = sigmas[step_indices].flatten()
384-
while len(sigma.shape) < len(original_samples.shape):
385-
sigma = sigma.unsqueeze(-1)
386-
387-
if self._schedule.__class__.__name__ == "FlowMatchSchedule":
388-
noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
389-
else:
390-
noisy_samples = original_samples + noise * sigma
391-
return noisy_samples
392-
393-
def __len__(self):
394-
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)