@@ -71,14 +71,30 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
7171
7272
7373class ConsistencyDecoderScheduler (SchedulerMixin , ConfigMixin ):
74+ """
75+ A scheduler for the consistency decoder used in Stable Diffusion pipelines.
76+
77+ This scheduler implements a two-step denoising process using consistency models for decoding latent representations
78+ into images.
79+
80+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
81+ methods the library implements for all schedulers such as loading and saving.
82+
83+ Args:
84+ num_train_timesteps (`int`, *optional*, defaults to `1024`):
85+ The number of diffusion steps to train the model.
86+ sigma_data (`float`, *optional*, defaults to `0.5`):
87+ The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
88+ """
89+
7490 order = 1
7591
7692 @register_to_config
7793 def __init__ (
7894 self ,
7995 num_train_timesteps : int = 1024 ,
8096 sigma_data : float = 0.5 ,
81- ):
97+ ) -> None :
8298 betas = betas_for_alpha_bar (num_train_timesteps )
8399
84100 alphas = 1.0 - betas
@@ -98,8 +114,18 @@ def __init__(
98114 def set_timesteps (
99115 self ,
100116 num_inference_steps : Optional [int ] = None ,
101- device : Union [str , torch .device ] = None ,
102- ):
117+ device : Optional [Union [str , torch .device ]] = None ,
118+ ) -> None :
119+ """
120+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
121+
122+ Args:
123+ num_inference_steps (`int`, *optional*):
124+ The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
125+ `2` inference steps are supported.
126+ device (`str` or `torch.device`, *optional*):
127+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
128+ """
103129 if num_inference_steps != 2 :
104130 raise ValueError ("Currently more than 2 inference steps are not supported." )
105131
@@ -111,7 +137,15 @@ def set_timesteps(
111137 self .c_in = self .c_in .to (device )
112138
113139 @property
114- def init_noise_sigma (self ):
140+ def init_noise_sigma (self ) -> torch .Tensor :
141+ """
142+ Return the standard deviation of the initial noise distribution.
143+
144+ Returns:
145+ `torch.Tensor`:
146+ The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
147+ timestep.
148+ """
115149 return self .sqrt_one_minus_alphas_cumprod [self .timesteps [0 ]]
116150
117151 def scale_model_input (self , sample : torch .Tensor , timestep : Optional [int ] = None ) -> torch .Tensor :
@@ -146,20 +180,20 @@ def step(
146180 Args:
147181 model_output (`torch.Tensor`):
148182 The direct output from the learned diffusion model.
149- timestep (`float`):
183+ timestep (`float` or `torch.Tensor` ):
150184 The current timestep in the diffusion chain.
151185 sample (`torch.Tensor`):
152186 A current instance of a sample created by the diffusion process.
153187 generator (`torch.Generator`, *optional*):
154- A random number generator.
188+ A random number generator for reproducibility.
155189 return_dict (`bool`, *optional*, defaults to `True`):
156190 Whether or not to return a
157- [`~schedulers.scheduling_consistency_models .ConsistencyDecoderSchedulerOutput`] or `tuple`.
191+ [`~schedulers.scheduling_consistency_decoder .ConsistencyDecoderSchedulerOutput`] or `tuple`.
158192
159193 Returns:
160- [`~schedulers.scheduling_consistency_models .ConsistencyDecoderSchedulerOutput`] or `tuple`:
161- If return_dict is `True`,
162- [`~schedulers.scheduling_consistency_models .ConsistencyDecoderSchedulerOutput`] is returned, otherwise
194+ [`~schedulers.scheduling_consistency_decoder .ConsistencyDecoderSchedulerOutput`] or `tuple`:
195+ If `return_dict` is `True`,
196+ [`~schedulers.scheduling_consistency_decoder .ConsistencyDecoderSchedulerOutput`] is returned, otherwise
163197 a tuple is returned where the first element is the sample tensor.
164198 """
165199 x_0 = self .c_out [timestep ] * model_output + self .c_skip [timestep ] * sample
0 commit comments