Commit 9fb6b89

Improve docstrings and type hints in scheduling_edm_euler.py (#12871)
* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.
* refactor: Add type hints to DPM-Solver scheduler methods.
1 parent 6fb4c99 commit 9fb6b89

3 files changed: +386 -89 lines changed


src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 100 additions & 14 deletions
```diff
@@ -143,7 +143,20 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
-    def precondition_inputs(self, sample, sigma):
+    def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the input sample by scaling it according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor to precondition.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The scaled input sample.
+        """
         c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample
```
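
To make the new docstring concrete: a minimal, self-contained sketch of what `precondition_inputs` computes, assuming the EDM default `sigma_data = 0.5` (the values here are illustrative, not taken from this diff).

```python
import torch

sigma_data = 0.5  # illustrative EDM default; the scheduler reads config.sigma_data
sigma = 80.0      # a typical starting noise level

# A noisy sample whose variance is sigma**2 + sigma_data**2 ...
sample = torch.randn(1, 4, 64, 64) * (sigma**2 + sigma_data**2) ** 0.5

# ... is scaled by c_in = 1 / sqrt(sigma**2 + sigma_data**2),
# which normalizes it back to roughly unit variance.
c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
scaled_sample = sample * c_in
print(scaled_sample.std())  # ~1.0
```
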
```diff
@@ -155,7 +168,27 @@ def precondition_noise(self, sigma):
         return sigma.atan() / math.pi * 2
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
-    def precondition_outputs(self, sample, model_output, sigma):
+    def precondition_outputs(
+        self,
+        sample: torch.Tensor,
+        model_output: torch.Tensor,
+        sigma: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Precondition the model outputs according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor.
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The denoised sample computed by combining the skip connection and output scaling.
+        """
         sigma_data = self.config.sigma_data
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
```
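
The `c_skip` line above is half of the EDM output preconditioning; below is a hedged standalone sketch of the full combination the docstring describes, with the epsilon-prediction `c_out` from Karras et al. (2022) filled in as an assumption (the scheduler selects `c_out` based on `config.prediction_type`).

```python
import torch

def precondition_outputs_sketch(sample, model_output, sigma, sigma_data=0.5):
    # Skip connection keeps more of the noisy sample at low sigma.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    # Assumed epsilon-prediction output scale from the EDM paper.
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output

sample = torch.randn(1, 4, 64, 64)
model_output = torch.randn_like(sample)
denoised = precondition_outputs_sketch(sample, model_output, sigma=2.0)
print(denoised.shape)  # torch.Size([1, 4, 64, 64])
```
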

```diff
@@ -173,13 +206,13 @@ def precondition_outputs(self, sample, model_output, sigma):
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+        need to scale the denoising model input depending on the current timestep.
 
         Args:
             sample (`torch.Tensor`):
-                The input sample.
-            timestep (`int`, *optional*):
+                The input sample tensor.
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
 
         Returns:
```
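
A hedged usage sketch of where `scale_model_input` sits in a denoising loop, assuming the standard diffusers scheduler API, the default `EDMEulerScheduler` config, and a dummy model in place of a real UNet:

```python
import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()  # default EDM config
scheduler.set_timesteps(num_inference_steps=10)

# Start from pure noise at the scheduler's initial sigma.
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    # Per-timestep input scaling keeps the loop scheduler-agnostic.
    model_input = scheduler.scale_model_input(latents, t)
    noise_pred = torch.zeros_like(model_input)  # stand-in for unet(model_input, t)
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```
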
```diff
@@ -242,8 +275,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         self.noise_sampler = None
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
-    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _compute_karras_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values in [0, 1] representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed Karras sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
```
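
For reference, the Karras schedule named in the docstring can be computed standalone; `rho = 7.0` is the paper's recommended value and is assumed here rather than read from a scheduler config:

```python
import torch

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
ramp = torch.linspace(0, 1, 10)  # interpolation positions in [0, 1]

# Interpolate linearly in sigma**(1/rho) space, then raise back to rho.
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas[0].item(), sigmas[-1].item())  # ~80.0, ~0.002
```
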

```diff
@@ -254,10 +306,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
-    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Implementation closely follows k-diffusion.
-
+    def _compute_exponential_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
         https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed exponential sigma schedule.
         """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
```
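
The exponential schedule is simpler: sigmas spaced uniformly in log space, as in k-diffusion's `get_sigmas_exponential`. A standalone sketch under the same illustrative `sigma_min`/`sigma_max`:

```python
import math
import torch

sigma_min, sigma_max = 0.002, 80.0
num_steps = 10

# Uniform spacing in log(sigma), decaying from sigma_max to sigma_min.
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_steps).exp()
print(sigmas[0].item(), sigmas[-1].item())  # ~80.0, ~0.002
```
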
```diff
@@ -354,7 +423,10 @@ def dpm_solver_first_order_update(
             `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
-        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
```
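
For context on the lines being rewrapped: `lambda = log(alpha) - log(sigma)` is the half log-SNR, and the first-order DPM-Solver++ update built on it has a closed form. A hedged numeric sketch of that structure (it mirrors the update, not the method's exact code path):

```python
import torch

def first_order_update_sketch(x_s, x0, alpha_s, sigma_s, alpha_t, sigma_t):
    # Half log-SNR at the previous and current steps.
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    h = lambda_t - lambda_s
    # Data-prediction (DPM-Solver++) first-order step.
    return (sigma_t / sigma_s) * x_s - alpha_t * (torch.exp(-h) - 1.0) * x0

x_s = torch.randn(1, 4, 64, 64)  # current sample
x0 = torch.zeros_like(x_s)       # preconditioned model estimate of x0
alpha_s, sigma_s = torch.tensor(0.9), torch.tensor(0.44)
alpha_t, sigma_t = torch.tensor(0.95), torch.tensor(0.31)
print(first_order_update_sketch(x_s, x0, alpha_s, sigma_s, alpha_t, sigma_t).shape)
```
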
```diff
@@ -540,7 +612,10 @@ def step(
                     [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
                 )
             self.noise_sampler = BrownianTreeNoiseSampler(
-                model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
+                model_output,
+                sigma_min=self.config.sigma_min,
+                sigma_max=self.config.sigma_max,
+                seed=seed,
             )
         noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
             model_output.device
```
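
The seed plumbing above, in isolation: `torch.Generator.initial_seed()` recovers the seed that was set, so the Brownian-tree noise sampler can be re-seeded deterministically; a list of generators yields a list of seeds.

```python
import torch

generator = torch.Generator().manual_seed(42)
seed = (
    [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
print(seed)  # 42
```
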
```diff
@@ -612,7 +687,18 @@ def add_noise(
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
-    def _get_conditioning_c_in(self, sigma):
+    def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Compute the input conditioning factor for the EDM formulation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
```
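
A quick numeric check of the documented `c_in` factor and its limiting behavior: near `1 / sigma_data` at low noise and near `1 / sigma` at high noise, which is what keeps the preconditioned input at roughly unit scale (`sigma_data = 0.5` is again an illustrative default).

```python
sigma_data = 0.5  # illustrative EDM default

for sigma in (1e-3, 1.0, 80.0):
    c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
    print(f"sigma={sigma:>6}: c_in={c_in:.4f}")
# sigma= 0.001: c_in=2.0000  (~1 / sigma_data)
# sigma=   1.0: c_in=0.8944
# sigma=  80.0: c_in=0.0125  (~1 / sigma)
```
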

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 112 additions & 15 deletions
```diff
@@ -175,13 +175,37 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
-    def precondition_inputs(self, sample, sigma):
+    def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the input sample by scaling it according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor to precondition.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The scaled input sample.
+        """
         c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
-    def precondition_noise(self, sigma):
+    def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the noise level by applying a logarithmic transformation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The sigma (noise level) value to precondition.
+
+        Returns:
+            `torch.Tensor`:
+                The preconditioned noise value computed as `0.25 * log(sigma)`.
+        """
         if not isinstance(sigma, torch.Tensor):
             sigma = torch.tensor([sigma])
```
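
The noise preconditioning documented above differs from the cosine variant's `atan` mapping: here the network is conditioned on `c_noise = 0.25 * log(sigma)`, per the EDM parameterization. In isolation:

```python
import torch

sigma = torch.tensor([80.0])
c_noise = 0.25 * torch.log(sigma)
print(c_noise)  # tensor([1.0955])
```
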

```diff
@@ -190,7 +214,27 @@ def precondition_noise(self, sigma):
         return c_noise
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
-    def precondition_outputs(self, sample, model_output, sigma):
+    def precondition_outputs(
+        self,
+        sample: torch.Tensor,
+        model_output: torch.Tensor,
+        sigma: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Precondition the model outputs according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor.
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The denoised sample computed by combining the skip connection and output scaling.
+        """
         sigma_data = self.config.sigma_data
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
```

```diff
@@ -208,13 +252,13 @@ def precondition_outputs(self, sample, model_output, sigma):
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+        need to scale the denoising model input depending on the current timestep.
 
         Args:
             sample (`torch.Tensor`):
-                The input sample.
-            timestep (`int`, *optional*):
+                The input sample tensor.
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
 
         Returns:
```
```diff
@@ -274,8 +318,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
-    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _compute_karras_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values in [0, 1] representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed Karras sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
```

```diff
@@ -286,10 +349,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
         return sigmas
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
-    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Implementation closely follows k-diffusion.
-
+    def _compute_exponential_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
         https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed exponential sigma schedule.
         """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
```
```diff
@@ -433,7 +513,10 @@ def dpm_solver_first_order_update(
             `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
-        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
```
```diff
@@ -684,7 +767,10 @@ def step(
 
         if self.config.algorithm_type == "sde-dpmsolver++":
             noise = randn_tensor(
-                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                model_output.shape,
+                generator=generator,
+                device=model_output.device,
+                dtype=model_output.dtype,
             )
         else:
             noise = None
```
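
The SDE branch draws Gaussian noise matched to the model output's shape, device, and dtype; `randn_tensor` is a diffusers utility wrapping `torch.randn` with generator handling. A plain-torch equivalent sketch:

```python
import torch

model_output = torch.randn(1, 4, 64, 64)
generator = torch.Generator(device="cpu").manual_seed(0)

# Same shape/device/dtype as the model output, reproducible via the generator.
noise = torch.randn(
    model_output.shape,
    generator=generator,
    device=model_output.device,
    dtype=model_output.dtype,
)
print(noise.shape, noise.dtype)  # torch.Size([1, 4, 64, 64]) torch.float32
```
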
```diff
@@ -757,7 +843,18 @@ def add_noise(
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
-    def _get_conditioning_c_in(self, sigma):
+    def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Compute the input conditioning factor for the EDM formulation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
```
