Skip to content

Commit a88a7b4

Browse files
authored
Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)
* Improve docstrings and type hints in multiple diffusion schedulers * docs: update Imagen Video paper link to Hugging Face Papers.
1 parent c8656ed commit a88a7b4

8 files changed

+329
-52
lines changed

src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
429429
return x_t
430430

431431
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
432-
def index_for_timestep(self, timestep, schedule_timesteps=None):
432+
def index_for_timestep(
433+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
434+
) -> int:
435+
"""
436+
Find the index for a given timestep in the schedule.
437+
438+
Args:
439+
timestep (`int` or `torch.Tensor`):
440+
The timestep for which to find the index.
441+
schedule_timesteps (`torch.Tensor`, *optional*):
442+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
443+
444+
Returns:
445+
`int`:
446+
The index of the timestep in the schedule.
447+
"""
433448
if schedule_timesteps is None:
434449
schedule_timesteps = self.timesteps
435450

@@ -452,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
452467
def _init_step_index(self, timestep):
453468
"""
454469
Initialize the step_index counter for the scheduler.
470+
471+
Args:
472+
timestep (`int` or `torch.Tensor`):
473+
The current timestep for which to initialize the step index.
455474
"""
456475

457476
if self.begin_index is None:

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
401401

402402
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
403403
def _sigma_to_alpha_sigma_t(self, sigma):
404+
"""
405+
Convert sigma values to alpha_t and sigma_t values.
406+
407+
Args:
408+
sigma (`torch.Tensor`):
409+
The sigma value(s) to convert.
410+
411+
Returns:
412+
`Tuple[torch.Tensor, torch.Tensor]`:
413+
A tuple containing (alpha_t, sigma_t) values.
414+
"""
404415
if self.config.use_flow_sigmas:
405416
alpha_t = 1 - sigma
406417
sigma_t = sigma
@@ -808,7 +819,22 @@ def ind_fn(t, b, c, d):
808819
raise NotImplementedError("only support log-rho multistep deis now")
809820

810821
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
811-
def index_for_timestep(self, timestep, schedule_timesteps=None):
822+
def index_for_timestep(
823+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
824+
) -> int:
825+
"""
826+
Find the index for a given timestep in the schedule.
827+
828+
Args:
829+
timestep (`int` or `torch.Tensor`):
830+
The timestep for which to find the index.
831+
schedule_timesteps (`torch.Tensor`, *optional*):
832+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
833+
834+
Returns:
835+
`int`:
836+
The index of the timestep in the schedule.
837+
"""
812838
if schedule_timesteps is None:
813839
schedule_timesteps = self.timesteps
814840

@@ -831,6 +857,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
831857
def _init_step_index(self, timestep):
832858
"""
833859
Initialize the step_index counter for the scheduler.
860+
861+
Args:
862+
timestep (`int` or `torch.Tensor`):
863+
The current timestep for which to initialize the step index.
834864
"""
835865

836866
if self.begin_index is None:
@@ -927,6 +957,21 @@ def add_noise(
927957
noise: torch.Tensor,
928958
timesteps: torch.IntTensor,
929959
) -> torch.Tensor:
960+
"""
961+
Add noise to the original samples according to the noise schedule at the specified timesteps.
962+
963+
Args:
964+
original_samples (`torch.Tensor`):
965+
The original samples without noise.
966+
noise (`torch.Tensor`):
967+
The noise to add to the samples.
968+
timesteps (`torch.IntTensor`):
969+
The timesteps at which to add noise to the samples.
970+
971+
Returns:
972+
`torch.Tensor`:
973+
The noisy samples.
974+
"""
930975
# Make sure sigmas and timesteps have the same device and dtype as original_samples
931976
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
932977
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 109 additions & 46 deletions
Large diffs are not rendered by default.

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
413413

414414
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
415415
def _sigma_to_alpha_sigma_t(self, sigma):
416+
"""
417+
Convert sigma values to alpha_t and sigma_t values.
418+
419+
Args:
420+
sigma (`torch.Tensor`):
421+
The sigma value(s) to convert.
422+
423+
Returns:
424+
`Tuple[torch.Tensor, torch.Tensor]`:
425+
A tuple containing (alpha_t, sigma_t) values.
426+
"""
416427
if self.config.use_flow_sigmas:
417428
alpha_t = 1 - sigma
418429
sigma_t = sigma

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
491491

492492
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
493493
def _sigma_to_alpha_sigma_t(self, sigma):
494+
"""
495+
Convert sigma values to alpha_t and sigma_t values.
496+
497+
Args:
498+
sigma (`torch.Tensor`):
499+
The sigma value(s) to convert.
500+
501+
Returns:
502+
`Tuple[torch.Tensor, torch.Tensor]`:
503+
A tuple containing (alpha_t, sigma_t) values.
504+
"""
494505
if self.config.use_flow_sigmas:
495506
alpha_t = 1 - sigma
496507
sigma_t = sigma
@@ -1079,7 +1090,22 @@ def singlestep_dpm_solver_update(
10791090
raise ValueError(f"Order must be 1, 2, 3, got {order}")
10801091

10811092
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
1082-
def index_for_timestep(self, timestep, schedule_timesteps=None):
1093+
def index_for_timestep(
1094+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
1095+
) -> int:
1096+
"""
1097+
Find the index for a given timestep in the schedule.
1098+
1099+
Args:
1100+
timestep (`int` or `torch.Tensor`):
1101+
The timestep for which to find the index.
1102+
schedule_timesteps (`torch.Tensor`, *optional*):
1103+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
1104+
1105+
Returns:
1106+
`int`:
1107+
The index of the timestep in the schedule.
1108+
"""
10831109
if schedule_timesteps is None:
10841110
schedule_timesteps = self.timesteps
10851111

@@ -1102,6 +1128,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
11021128
def _init_step_index(self, timestep):
11031129
"""
11041130
Initialize the step_index counter for the scheduler.
1131+
1132+
Args:
1133+
timestep (`int` or `torch.Tensor`):
1134+
The current timestep for which to initialize the step index.
11051135
"""
11061136

11071137
if self.begin_index is None:
@@ -1204,6 +1234,21 @@ def add_noise(
12041234
noise: torch.Tensor,
12051235
timesteps: torch.IntTensor,
12061236
) -> torch.Tensor:
1237+
"""
1238+
Add noise to the original samples according to the noise schedule at the specified timesteps.
1239+
1240+
Args:
1241+
original_samples (`torch.Tensor`):
1242+
The original samples without noise.
1243+
noise (`torch.Tensor`):
1244+
The noise to add to the samples.
1245+
timesteps (`torch.IntTensor`):
1246+
The timesteps at which to add noise to the samples.
1247+
1248+
Returns:
1249+
`torch.Tensor`:
1250+
The noisy samples.
1251+
"""
12071252
# Make sure sigmas and timesteps have the same device and dtype as original_samples
12081253
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
12091254
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,22 @@ def multistep_dpm_solver_third_order_update(
578578
return x_t
579579

580580
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
581-
def index_for_timestep(self, timestep, schedule_timesteps=None):
581+
def index_for_timestep(
582+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
583+
) -> int:
584+
"""
585+
Find the index for a given timestep in the schedule.
586+
587+
Args:
588+
timestep (`int` or `torch.Tensor`):
589+
The timestep for which to find the index.
590+
schedule_timesteps (`torch.Tensor`, *optional*):
591+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
592+
593+
Returns:
594+
`int`:
595+
The index of the timestep in the schedule.
596+
"""
582597
if schedule_timesteps is None:
583598
schedule_timesteps = self.timesteps
584599

@@ -601,6 +616,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
601616
def _init_step_index(self, timestep):
602617
"""
603618
Initialize the step_index counter for the scheduler.
619+
620+
Args:
621+
timestep (`int` or `torch.Tensor`):
622+
The current timestep for which to initialize the step index.
604623
"""
605624

606625
if self.begin_index is None:

src/diffusers/schedulers/scheduling_sasolver.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
423423

424424
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
425425
def _sigma_to_alpha_sigma_t(self, sigma):
426+
"""
427+
Convert sigma values to alpha_t and sigma_t values.
428+
429+
Args:
430+
sigma (`torch.Tensor`):
431+
The sigma value(s) to convert.
432+
433+
Returns:
434+
`Tuple[torch.Tensor, torch.Tensor]`:
435+
A tuple containing (alpha_t, sigma_t) values.
436+
"""
426437
if self.config.use_flow_sigmas:
427438
alpha_t = 1 - sigma
428439
sigma_t = sigma
@@ -1103,7 +1114,22 @@ def stochastic_adams_moulton_update(
11031114
return x_t
11041115

11051116
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
1106-
def index_for_timestep(self, timestep, schedule_timesteps=None):
1117+
def index_for_timestep(
1118+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
1119+
) -> int:
1120+
"""
1121+
Find the index for a given timestep in the schedule.
1122+
1123+
Args:
1124+
timestep (`int` or `torch.Tensor`):
1125+
The timestep for which to find the index.
1126+
schedule_timesteps (`torch.Tensor`, *optional*):
1127+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
1128+
1129+
Returns:
1130+
`int`:
1131+
The index of the timestep in the schedule.
1132+
"""
11071133
if schedule_timesteps is None:
11081134
schedule_timesteps = self.timesteps
11091135

@@ -1126,6 +1152,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
11261152
def _init_step_index(self, timestep):
11271153
"""
11281154
Initialize the step_index counter for the scheduler.
1155+
1156+
Args:
1157+
timestep (`int` or `torch.Tensor`):
1158+
The current timestep for which to initialize the step index.
11291159
"""
11301160

11311161
if self.begin_index is None:

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
513513

514514
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
515515
def _sigma_to_alpha_sigma_t(self, sigma):
516+
"""
517+
Convert sigma values to alpha_t and sigma_t values.
518+
519+
Args:
520+
sigma (`torch.Tensor`):
521+
The sigma value(s) to convert.
522+
523+
Returns:
524+
`Tuple[torch.Tensor, torch.Tensor]`:
525+
A tuple containing (alpha_t, sigma_t) values.
526+
"""
516527
if self.config.use_flow_sigmas:
517528
alpha_t = 1 - sigma
518529
sigma_t = sigma
@@ -984,7 +995,22 @@ def multistep_uni_c_bh_update(
984995
return x_t
985996

986997
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
987-
def index_for_timestep(self, timestep, schedule_timesteps=None):
998+
def index_for_timestep(
999+
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
1000+
) -> int:
1001+
"""
1002+
Find the index for a given timestep in the schedule.
1003+
1004+
Args:
1005+
timestep (`int` or `torch.Tensor`):
1006+
The timestep for which to find the index.
1007+
schedule_timesteps (`torch.Tensor`, *optional*):
1008+
The timestep schedule to search in. If `None`, uses `self.timesteps`.
1009+
1010+
Returns:
1011+
`int`:
1012+
The index of the timestep in the schedule.
1013+
"""
9881014
if schedule_timesteps is None:
9891015
schedule_timesteps = self.timesteps
9901016

@@ -1007,6 +1033,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
10071033
def _init_step_index(self, timestep):
10081034
"""
10091035
Initialize the step_index counter for the scheduler.
1036+
1037+
Args:
1038+
timestep (`int` or `torch.Tensor`):
1039+
The current timestep for which to initialize the step index.
10101040
"""
10111041

10121042
if self.begin_index is None:
@@ -1119,6 +1149,21 @@ def add_noise(
11191149
noise: torch.Tensor,
11201150
timesteps: torch.IntTensor,
11211151
) -> torch.Tensor:
1152+
"""
1153+
Add noise to the original samples according to the noise schedule at the specified timesteps.
1154+
1155+
Args:
1156+
original_samples (`torch.Tensor`):
1157+
The original samples without noise.
1158+
noise (`torch.Tensor`):
1159+
The noise to add to the samples.
1160+
timesteps (`torch.IntTensor`):
1161+
The timesteps at which to add noise to the samples.
1162+
1163+
Returns:
1164+
`torch.Tensor`:
1165+
The noisy samples.
1166+
"""
11221167
# Make sure sigmas and timesteps have the same device and dtype as original_samples
11231168
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
11241169
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

0 commit comments

Comments
 (0)