2020from ..configuration_utils import ConfigMixin , register_to_config
2121from ..utils import BaseOutput , logging
2222from ..utils .torch_utils import randn_tensor
23- from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin
23+ from .scheduling_utils import KarrasDiffusionSchedulers , SchedulerMixin , SamplingMixin
2424
2525
2626logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
@@ -45,7 +45,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
4545 pred_original_sample : Optional [torch .Tensor ] = None
4646
4747
48- class EulerDiscreteScheduler (SchedulerMixin , ConfigMixin ):
48+ class EulerDiscreteScheduler (SchedulerMixin , ConfigMixin , SamplingMixin ):
4949 """
5050 Euler scheduler.
5151
@@ -106,13 +106,6 @@ def __init__(
106106 self .set_schedule (schedule_config )
107107 self .set_sigma_schedule (sigma_schedule_config )
108108
109- # setable values
110- self .num_inference_steps = None
111-
112- self .is_scale_input_called = False
113- self ._step_index = None
114- self ._begin_index = None
115-
116109 @property
117110 def init_noise_sigma (self ):
118111 # standard deviation of the initial noise distribution
@@ -122,31 +115,6 @@ def init_noise_sigma(self):
122115
123116 return (max_sigma ** 2 + 1 ) ** 0.5
124117
125- @property
126- def step_index (self ):
127- """
128- The index counter for current timestep. It will increase 1 after each scheduler step.
129- """
130- return self ._step_index
131-
132- @property
133- def begin_index (self ):
134- """
135- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
136- """
137- return self ._begin_index
138-
139- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
140- def set_begin_index (self , begin_index : int = 0 ):
141- """
142- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
143-
144- Args:
145- begin_index (`int`):
146- The begin index for the scheduler.
147- """
148- self ._begin_index = begin_index
149-
150118 def scale_model_input (self , sample : torch .Tensor , timestep : Union [float , torch .Tensor ]) -> torch .Tensor :
151119 """
152120 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -188,15 +156,6 @@ def set_timesteps(
188156 The number of diffusion steps used when generating samples with a pre-trained model.
189157 device (`str` or `torch.device`, *optional*):
190158 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
191- timesteps (`List[int]`, *optional*):
192- Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
193- based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
194- must be `None`, and `timestep_spacing` attribute will be ignored.
195- sigmas (`List[float]`, *optional*):
196- Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
197- will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
198- `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
199- custom sigmas schedule.
200159 """
201160
202161 if timesteps is not None and sigmas is not None :
@@ -210,7 +169,7 @@ def set_timesteps(
210169 and self ._sigma_schedule is not None
211170 and self ._sigma_schedule .__class__ .__name__ == "KarrasSigmas"
212171 ):
213- raise ValueError ("Cannot set `timesteps` with `config.use_karras_sigmas = True `." )
172+ raise ValueError ("Cannot set `timesteps` with `KarrasSigmas `." )
214173 if (
215174 timesteps is not None
216175 and self ._sigma_schedule is not None
@@ -225,11 +184,11 @@ def set_timesteps(
225184 raise ValueError ("Cannot set `timesteps` with `BetaSigmas`." )
226185 if (
227186 timesteps is not None
228- and self ._schedule .config . get ( " timestep_type" , None ) == "continuous"
187+ and self ._schedule .timestep_type == "continuous"
229188 and self .config .prediction_type == "v_prediction"
230189 ):
231190 raise ValueError (
232- "Cannot set `timesteps` with `config .timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
191+ "Cannot set `timesteps` with `schedule .timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
233192 )
234193
235194 if num_inference_steps is None :
@@ -248,30 +207,8 @@ def set_timesteps(
248207
249208 self ._step_index = None
250209 self ._begin_index = None
251- self .timesteps = timesteps
252- self .sigmas = sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
253-
254- def index_for_timestep (self , timestep , schedule_timesteps = None ):
255- if schedule_timesteps is None :
256- schedule_timesteps = self .timesteps
257-
258- indices = (schedule_timesteps == timestep ).nonzero ()
259-
260- # The sigma index that is taken for the **very** first `step`
261- # is always the second index (or the last index if there is only 1)
262- # This way we can ensure we don't accidentally skip a sigma in
263- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
264- pos = 1 if len (indices ) > 1 else 0
265-
266- return indices [pos ].item ()
267-
268- def _init_step_index (self , timestep ):
269- if self .begin_index is None :
270- if isinstance (timestep , torch .Tensor ):
271- timestep = timestep .to (self .timesteps .device )
272- self ._step_index = self .index_for_timestep (timestep )
273- else :
274- self ._step_index = self ._begin_index
210+ self .timesteps = timesteps .to (device = device )
211+ self .sigmas = sigmas .to ("cpu" )
275212
276213 def step (
277214 self ,
@@ -382,79 +319,3 @@ def step(
382319 )
383320
384321 return EulerDiscreteSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_original_sample )
385-
386- def add_noise (
387- self ,
388- original_samples : torch .Tensor ,
389- noise : torch .Tensor ,
390- timesteps : torch .Tensor ,
391- ) -> torch .Tensor :
392- # Make sure sigmas and timesteps have the same device and dtype as original_samples
393- sigmas = self .sigmas .to (device = original_samples .device , dtype = original_samples .dtype )
394- if original_samples .device .type == "mps" and torch .is_floating_point (timesteps ):
395- # mps does not support float64
396- schedule_timesteps = self .timesteps .to (original_samples .device , dtype = torch .float32 )
397- timesteps = timesteps .to (original_samples .device , dtype = torch .float32 )
398- else :
399- schedule_timesteps = self .timesteps .to (original_samples .device )
400- timesteps = timesteps .to (original_samples .device )
401-
402- # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
403- if self .begin_index is None :
404- step_indices = [self .index_for_timestep (t , schedule_timesteps ) for t in timesteps ]
405- elif self .step_index is not None :
406- # add_noise is called after first denoising step (for inpainting)
407- step_indices = [self .step_index ] * timesteps .shape [0 ]
408- else :
409- # add noise is called before first denoising step to create initial latent(img2img)
410- step_indices = [self .begin_index ] * timesteps .shape [0 ]
411-
412- sigma = sigmas [step_indices ].flatten ()
413- while len (sigma .shape ) < len (original_samples .shape ):
414- sigma = sigma .unsqueeze (- 1 )
415-
416- if self ._schedule .__class__ .__name__ == "FlowMatchSchedule" :
417- noisy_samples = (1.0 - sigma ) * original_samples + noise * sigma
418- else :
419- noisy_samples = original_samples + noise * sigma
420- return noisy_samples
421-
422- def get_velocity (self , sample : torch .Tensor , noise : torch .Tensor , timesteps : torch .Tensor ) -> torch .Tensor :
423- if (
424- isinstance (timesteps , int )
425- or isinstance (timesteps , torch .IntTensor )
426- or isinstance (timesteps , torch .LongTensor )
427- ):
428- raise ValueError (
429- (
430- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
431- " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
432- " one of the `scheduler.timesteps` as a timestep."
433- ),
434- )
435-
436- if sample .device .type == "mps" and torch .is_floating_point (timesteps ):
437- # mps does not support float64
438- schedule_timesteps = self .timesteps .to (sample .device , dtype = torch .float32 )
439- timesteps = timesteps .to (sample .device , dtype = torch .float32 )
440- else :
441- schedule_timesteps = self .timesteps .to (sample .device )
442- timesteps = timesteps .to (sample .device )
443-
444- step_indices = [self .index_for_timestep (t , schedule_timesteps ) for t in timesteps ]
445- alphas_cumprod = self .alphas_cumprod .to (sample )
446- sqrt_alpha_prod = alphas_cumprod [step_indices ] ** 0.5
447- sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
448- while len (sqrt_alpha_prod .shape ) < len (sample .shape ):
449- sqrt_alpha_prod = sqrt_alpha_prod .unsqueeze (- 1 )
450-
451- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod [step_indices ]) ** 0.5
452- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .flatten ()
453- while len (sqrt_one_minus_alpha_prod .shape ) < len (sample .shape ):
454- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .unsqueeze (- 1 )
455-
456- velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
457- return velocity
458-
459- def __len__ (self ):
460- return self .config .num_train_timesteps
0 commit comments