2020from ..configuration_utils import ConfigMixin , register_to_config
2121from ..utils import BaseOutput , logging
2222from ..utils .torch_utils import randn_tensor
23- from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin
23+ from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SamplingMixin
2424
2525
2626logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
@@ -45,7 +45,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
4545 pred_original_sample : Optional [torch .Tensor ] = None
4646
4747
48- class EulerAncestralDiscreteScheduler (SchedulerMixin , ConfigMixin ):
48+ class EulerAncestralDiscreteScheduler (SchedulerMixin , ConfigMixin , SamplingMixin ):
4949 """
5050 Ancestral sampling with Euler method steps.
5151
@@ -92,46 +92,17 @@ def __init__(
9292 self .set_schedule (schedule_config )
9393 self .set_sigma_schedule (sigma_schedule_config )
9494
95- # setable values
96- self .num_inference_steps = None
97-
98- self .is_scale_input_called = False
99- self ._step_index = None
100- self ._begin_index = None
101-
10295 @property
96+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.init_noise_sigma
10397 def init_noise_sigma (self ):
10498 # standard deviation of the initial noise distribution
99+ max_sigma = max (self .sigmas ) if isinstance (self .sigmas , list ) else self .sigmas .max ()
105100 if self .config .timestep_spacing in ["linspace" , "trailing" ]:
106- return self .sigmas .max ()
107-
108- return (self .sigmas .max () ** 2 + 1 ) ** 0.5
109-
110- @property
111- def step_index (self ):
112- """
113- The index counter for current timestep. It will increase 1 after each scheduler step.
114- """
115- return self ._step_index
101+ return max_sigma
116102
117- @property
118- def begin_index (self ):
119- """
120- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
121- """
122- return self ._begin_index
123-
124- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
125- def set_begin_index (self , begin_index : int = 0 ):
126- """
127- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
128-
129- Args:
130- begin_index (`int`):
131- The begin index for the scheduler.
132- """
133- self ._begin_index = begin_index
103+ return (max_sigma ** 2 + 1 ) ** 0.5
134104
105+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input
135106 def scale_model_input (self , sample : torch .Tensor , timestep : Union [float , torch .Tensor ]) -> torch .Tensor :
136107 """
137108 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -147,18 +118,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
147118 `torch.Tensor`:
148119 A scaled input sample.
149120 """
150-
151121 if self .step_index is None :
152122 self ._init_step_index (timestep )
153123
154124 sigma = self .sigmas [self .step_index ]
155125 sample = sample / ((sigma ** 2 + 1 ) ** 0.5 )
126+
156127 self .is_scale_input_called = True
157128 return sample
158129
130+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.set_timesteps
159131 def set_timesteps (
160132 self ,
161- num_inference_steps : int ,
133+ num_inference_steps : int = None ,
162134 device : Union [str , torch .device ] = None ,
163135 timesteps : Optional [List [int ]] = None ,
164136 sigmas : Optional [List [float ]] = None ,
@@ -201,11 +173,11 @@ def set_timesteps(
201173 raise ValueError ("Cannot set `timesteps` with `BetaSigmas`." )
202174 if (
203175 timesteps is not None
204- and self ._schedule .config . get ( " timestep_type" , None ) == "continuous"
176+ and self ._schedule .timestep_type == "continuous"
205177 and self .config .prediction_type == "v_prediction"
206178 ):
207179 raise ValueError (
208- "Cannot set `timesteps` with `config .timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
180+ "Cannot set `timesteps` with `schedule .timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
209181 )
210182
211183 if num_inference_steps is None :
@@ -222,35 +194,11 @@ def set_timesteps(
222194 shift = shift ,
223195 )
224196
225- self .timesteps = timesteps .to (device = device )
226197 self ._step_index = None
227198 self ._begin_index = None
199+ self .timesteps = timesteps .to (device = device )
228200 self .sigmas = sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
229201
230- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
231- def index_for_timestep (self , timestep , schedule_timesteps = None ):
232- if schedule_timesteps is None :
233- schedule_timesteps = self .timesteps
234-
235- indices = (schedule_timesteps == timestep ).nonzero ()
236-
237- # The sigma index that is taken for the **very** first `step`
238- # is always the second index (or the last index if there is only 1)
239- # This way we can ensure we don't accidentally skip a sigma in
240- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
241- pos = 1 if len (indices ) > 1 else 0
242-
243- return indices [pos ].item ()
244-
245- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
246- def _init_step_index (self , timestep ):
247- if self .begin_index is None :
248- if isinstance (timestep , torch .Tensor ):
249- timestep = timestep .to (self .timesteps .device )
250- self ._step_index = self .index_for_timestep (timestep )
251- else :
252- self ._step_index = self ._begin_index
253-
254202 def step (
255203 self ,
256204 model_output : torch .Tensor ,
@@ -352,43 +300,3 @@ def step(
352300 return EulerAncestralDiscreteSchedulerOutput (
353301 prev_sample = prev_sample , pred_original_sample = pred_original_sample
354302 )
355-
356- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
357- def add_noise (
358- self ,
359- original_samples : torch .Tensor ,
360- noise : torch .Tensor ,
361- timesteps : torch .Tensor ,
362- ) -> torch .Tensor :
363- # Make sure sigmas and timesteps have the same device and dtype as original_samples
364- sigmas = self .sigmas .to (device = original_samples .device , dtype = original_samples .dtype )
365- if original_samples .device .type == "mps" and torch .is_floating_point (timesteps ):
366- # mps does not support float64
367- schedule_timesteps = self .timesteps .to (original_samples .device , dtype = torch .float32 )
368- timesteps = timesteps .to (original_samples .device , dtype = torch .float32 )
369- else :
370- schedule_timesteps = self .timesteps .to (original_samples .device )
371- timesteps = timesteps .to (original_samples .device )
372-
373- # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
374- if self .begin_index is None :
375- step_indices = [self .index_for_timestep (t , schedule_timesteps ) for t in timesteps ]
376- elif self .step_index is not None :
377- # add_noise is called after first denoising step (for inpainting)
378- step_indices = [self .step_index ] * timesteps .shape [0 ]
379- else :
380- # add noise is called before first denoising step to create initial latent(img2img)
381- step_indices = [self .begin_index ] * timesteps .shape [0 ]
382-
383- sigma = sigmas [step_indices ].flatten ()
384- while len (sigma .shape ) < len (original_samples .shape ):
385- sigma = sigma .unsqueeze (- 1 )
386-
387- if self ._schedule .__class__ .__name__ == "FlowMatchSchedule" :
388- noisy_samples = (1.0 - sigma ) * original_samples + noise * sigma
389- else :
390- noisy_samples = original_samples + noise * sigma
391- return noisy_samples
392-
393- def __len__ (self ):
394- return self .config .num_train_timesteps
0 commit comments