@@ -175,13 +175,37 @@ def set_begin_index(self, begin_index: int = 0):
175175 self ._begin_index = begin_index
176176
177177 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
178- def precondition_inputs (self , sample , sigma ):
178+ def precondition_inputs (self , sample : torch .Tensor , sigma : Union [float , torch .Tensor ]) -> torch .Tensor :
179+ """
180+ Precondition the input sample by scaling it according to the EDM formulation.
181+
182+ Args:
183+ sample (`torch.Tensor`):
184+ The input sample tensor to precondition.
185+ sigma (`float` or `torch.Tensor`):
186+ The current sigma (noise level) value.
187+
188+ Returns:
189+ `torch.Tensor`:
190+ The scaled input sample.
191+ """
179192 c_in = self ._get_conditioning_c_in (sigma )
180193 scaled_sample = sample * c_in
181194 return scaled_sample
182195
183196 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
184- def precondition_noise (self , sigma ):
197+ def precondition_noise (self , sigma : Union [float , torch .Tensor ]) -> torch .Tensor :
198+ """
199+ Precondition the noise level by applying a logarithmic transformation.
200+
201+ Args:
202+ sigma (`float` or `torch.Tensor`):
203+ The sigma (noise level) value to precondition.
204+
205+ Returns:
206+ `torch.Tensor`:
207+ The preconditioned noise value computed as `0.25 * log(sigma)`.
208+ """
185209 if not isinstance (sigma , torch .Tensor ):
186210 sigma = torch .tensor ([sigma ])
187211
@@ -190,7 +214,27 @@ def precondition_noise(self, sigma):
190214 return c_noise
191215
192216 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
193- def precondition_outputs (self , sample , model_output , sigma ):
217+ def precondition_outputs (
218+ self ,
219+ sample : torch .Tensor ,
220+ model_output : torch .Tensor ,
221+ sigma : Union [float , torch .Tensor ],
222+ ) -> torch .Tensor :
223+ """
224+ Precondition the model outputs according to the EDM formulation.
225+
226+ Args:
227+ sample (`torch.Tensor`):
228+ The input sample tensor.
229+ model_output (`torch.Tensor`):
230+ The direct output from the learned diffusion model.
231+ sigma (`float` or `torch.Tensor`):
232+ The current sigma (noise level) value.
233+
234+ Returns:
235+ `torch.Tensor`:
236+ The denoised sample computed by combining the skip connection and output scaling.
237+ """
194238 sigma_data = self .config .sigma_data
195239 c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2 )
196240
@@ -208,13 +252,13 @@ def precondition_outputs(self, sample, model_output, sigma):
208252 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
209253 def scale_model_input (self , sample : torch .Tensor , timestep : Union [float , torch .Tensor ]) -> torch .Tensor :
210254 """
211- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
212- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm .
255+ Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
256+ need to scale the denoising model input depending on the current timestep .
213257
214258 Args:
215259 sample (`torch.Tensor`):
216- The input sample.
217- timestep (`int`, *optional* ):
260+ The input sample tensor .
261+ timestep (`float` or `torch.Tensor` ):
218262 The current timestep in the diffusion chain.
219263
220264 Returns:
@@ -274,8 +318,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
274318 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
275319
276320 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
277- def _compute_karras_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .Tensor :
278- """Constructs the noise schedule of Karras et al. (2022)."""
321+ def _compute_karras_sigmas (
322+ self ,
323+ ramp : torch .Tensor ,
324+ sigma_min : Optional [float ] = None ,
325+ sigma_max : Optional [float ] = None ,
326+ ) -> torch .Tensor :
327+ """
328+ Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
329+
330+ Args:
331+ ramp (`torch.Tensor`):
332+ A tensor of values in [0, 1] representing the interpolation positions.
333+ sigma_min (`float`, *optional*):
334+ Minimum sigma value. Falls back to `self.config.sigma_min` when `None` or `0` (any falsy value).
335+ sigma_max (`float`, *optional*):
336+ Maximum sigma value. Falls back to `self.config.sigma_max` when `None` or `0` (any falsy value).
337+
338+ Returns:
339+ `torch.Tensor`:
340+ The computed Karras sigma schedule.
341+ """
279342 sigma_min = sigma_min or self .config .sigma_min
280343 sigma_max = sigma_max or self .config .sigma_max
281344
@@ -286,10 +349,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
286349 return sigmas
287350
288351 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
289- def _compute_exponential_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .Tensor :
290- """Implementation closely follows k-diffusion.
291-
352+ def _compute_exponential_sigmas (
353+ self ,
354+ ramp : torch .Tensor ,
355+ sigma_min : Optional [float ] = None ,
356+ sigma_max : Optional [float ] = None ,
357+ ) -> torch .Tensor :
358+ """
359+ Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
292360 https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
361+
362+ Args:
363+ ramp (`torch.Tensor`):
364+ A tensor of values representing the interpolation positions.
365+ sigma_min (`float`, *optional*):
366+ Minimum sigma value. Falls back to `self.config.sigma_min` when `None` or `0` (any falsy value).
367+ sigma_max (`float`, *optional*):
368+ Maximum sigma value. Falls back to `self.config.sigma_max` when `None` or `0` (any falsy value).
369+
370+ Returns:
371+ `torch.Tensor`:
372+ The computed exponential sigma schedule.
293373 """
294374 sigma_min = sigma_min or self .config .sigma_min
295375 sigma_max = sigma_max or self .config .sigma_max
@@ -433,7 +513,10 @@ def dpm_solver_first_order_update(
433513 `torch.Tensor`:
434514 The sample tensor at the previous timestep.
435515 """
436- sigma_t , sigma_s = self .sigmas [self .step_index + 1 ], self .sigmas [self .step_index ]
516+ sigma_t , sigma_s = (
517+ self .sigmas [self .step_index + 1 ],
518+ self .sigmas [self .step_index ],
519+ )
437520 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t )
438521 alpha_s , sigma_s = self ._sigma_to_alpha_sigma_t (sigma_s )
439522 lambda_t = torch .log (alpha_t ) - torch .log (sigma_t )
@@ -684,7 +767,10 @@ def step(
684767
685768 if self .config .algorithm_type == "sde-dpmsolver++" :
686769 noise = randn_tensor (
687- model_output .shape , generator = generator , device = model_output .device , dtype = model_output .dtype
770+ model_output .shape ,
771+ generator = generator ,
772+ device = model_output .device ,
773+ dtype = model_output .dtype ,
688774 )
689775 else :
690776 noise = None
@@ -757,7 +843,18 @@ def add_noise(
757843 return noisy_samples
758844
759845 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
760- def _get_conditioning_c_in (self , sigma ):
846+ def _get_conditioning_c_in (self , sigma : Union [float , torch .Tensor ]) -> Union [float , torch .Tensor ]:
847+ """
848+ Compute the input conditioning factor for the EDM formulation.
849+
850+ Args:
851+ sigma (`float` or `torch.Tensor`):
852+ The current sigma (noise level) value.
853+
854+ Returns:
855+ `float` or `torch.Tensor`:
856+ The input conditioning factor `c_in`.
857+ """
761858 c_in = 1 / ((sigma ** 2 + self .config .sigma_data ** 2 ) ** 0.5 )
762859 return c_in
763860
0 commit comments