
Commit b1f06b7

Improve docstrings and type hints in scheduling_consistency_decoder.py (#12928)
docs: improve docstring scheduling_consistency_decoder.py
1 parent 8600b4c commit b1f06b7

1 file changed: +44 −10 lines

src/diffusers/schedulers/scheduling_consistency_decoder.py

Lines changed: 44 additions & 10 deletions
@@ -71,14 +71,30 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
 
 
 class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
+    """
+    A scheduler for the consistency decoder used in Stable Diffusion pipelines.
+
+    This scheduler implements a two-step denoising process using consistency models for decoding latent representations
+    into images.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, *optional*, defaults to `1024`):
+            The number of diffusion steps to train the model.
+        sigma_data (`float`, *optional*, defaults to `0.5`):
+            The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
+    """
+
     order = 1
 
     @register_to_config
     def __init__(
         self,
         num_train_timesteps: int = 1024,
         sigma_data: float = 0.5,
-    ):
+    ) -> None:
         betas = betas_for_alpha_bar(num_train_timesteps)
 
         alphas = 1.0 - betas
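
As a quick illustration of the constructor documented above, here is a minimal sketch (not part of the commit); the import path follows the file changed here, and the arguments shown are just the documented defaults:

    from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler

    # Both arguments are optional and default to the values named in the new docstring.
    scheduler = ConsistencyDecoderScheduler(num_train_timesteps=1024, sigma_data=0.5)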
@@ -98,8 +114,18 @@ def __init__(
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
-    ):
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`, *optional*):
+                The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
+                `2` inference steps are supported.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        """
         if num_inference_steps != 2:
             raise ValueError("Currently more than 2 inference steps are not supported.")
 
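Continuing that sketch, the newly documented set_timesteps() can be exercised as follows; per the docstring and the check above, only exactly 2 inference steps are accepted:

    import torch

    # Anything other than num_inference_steps=2 raises the ValueError shown in the hunk above.
    scheduler.set_timesteps(num_inference_steps=2, device=torch.device("cpu"))
    print(scheduler.timesteps)  # the two discrete timesteps used for decoding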
@@ -111,7 +137,15 @@ def set_timesteps(
         self.c_in = self.c_in.to(device)
 
     @property
-    def init_noise_sigma(self):
+    def init_noise_sigma(self) -> torch.Tensor:
+        """
+        Return the standard deviation of the initial noise distribution.
+
+        Returns:
+            `torch.Tensor`:
+                The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
+                timestep.
+        """
         return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
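
The init_noise_sigma property documented above is typically used to scale the initial noise before the decoding loop. Continuing the sketch, with an illustrative sample shape that is not prescribed by this file:

    # Scale initial Gaussian noise by the documented initial sigma (usual diffusers pattern).
    sample = torch.randn(1, 3, 256, 256)  # illustrative shape only
    sample = sample * scheduler.init_noise_sigma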
@@ -146,20 +180,20 @@ def step(
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
-                A random number generator.
+                A random number generator for reproducibility.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a
-                [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
+                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.
 
         Returns:
-            [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
-                If return_dict is `True`,
-                [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
+            [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`,
+                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
                 a tuple is returned where the first element is the sample tensor.
         """
         x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
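
Finally, a hedged sketch of the two-step decoding loop around scale_model_input() and step(), continuing the snippets above; decoder_model is a hypothetical stand-in for the consistency decoder network and is not defined in this commit:

    # Illustrative two-step denoising loop; in practice this lives inside the
    # pipeline/VAE that owns the scheduler rather than in user code.
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(sample, t)
        model_output = decoder_model(model_input, t)  # hypothetical callable
        # With return_dict=True (the default), step() returns a
        # ConsistencyDecoderSchedulerOutput whose prev_sample is the denoised sample.
        sample = scheduler.step(model_output, t, sample).prev_sample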
