3636
3737def set_seed (seed : int ):
3838 """
39- Args:
4039 Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
40+
41+ Args:
4142 seed (`int`): The seed to set.
4243 """
4344 random .seed (seed )
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
194195
195196
196197def cast_training_params (model : Union [torch .nn .Module , List [torch .nn .Module ]], dtype = torch .float32 ):
198+ """
199+ Casts the training parameters of the model to the specified data type.
200+
201+ Args:
202+ model: The PyTorch model whose parameters will be cast.
203+ dtype: The data type to which the model parameters will be cast.
204+ """
197205 if not isinstance (model , list ):
198206 model = [model ]
199207 for m in model :
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
225233def compute_density_for_timestep_sampling (
226234 weighting_scheme : str , batch_size : int , logit_mean : float = None , logit_std : float = None , mode_scale : float = None
227235):
228- """Compute the density for sampling the timesteps when doing SD3 training.
236+ """
237+ Compute the density for sampling the timesteps when doing SD3 training.
229238
230239 Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
231240
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
244253
245254
246255def compute_loss_weighting_for_sd3 (weighting_scheme : str , sigmas = None ):
247- """Computes loss weighting scheme for SD3 training.
256+ """
257+ Computes loss weighting scheme for SD3 training.
248258
249259 Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
250260
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
261271
262272
263273def free_memory ():
264- """Runs garbage collection. Then clears the cache of the available accelerator."""
274+ """
275+ Runs garbage collection. Then clears the cache of the available accelerator.
276+ """
265277 gc .collect ()
266278
267279 if torch .cuda .is_available ():
@@ -494,7 +506,8 @@ def pin_memory(self) -> None:
494506 self .shadow_params = [p .pin_memory () for p in self .shadow_params ]
495507
496508 def to (self , device = None , dtype = None , non_blocking = False ) -> None :
497- r"""Move internal buffers of the ExponentialMovingAverage to `device`.
509+ r"""
510+ Move internal buffers of the ExponentialMovingAverage to `device`.
498511
499512 Args:
500513 device: like `device` argument to `torch.Tensor.to`
@@ -528,23 +541,25 @@ def state_dict(self) -> dict:
528541
529542 def store (self , parameters : Iterable [torch .nn .Parameter ]) -> None :
530543 r"""
544+ Saves the current parameters for restoring later.
545+
531546 Args:
532- Save the current parameters for restoring later.
533- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
534- temporarily stored.
547+ parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
535548 """
536549 self .temp_stored_params = [param .detach ().cpu ().clone () for param in parameters ]
537550
538551 def restore (self , parameters : Iterable [torch .nn .Parameter ]) -> None :
539552 r"""
540- Args:
541- Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
542- affecting the original optimization process. Store the parameters before the `copy_to()` method. After
553+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
554+ without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
543555 validation (or model saving), use this to restore the former parameters.
556+
557+ Args:
544558 parameters: Iterable of `torch.nn.Parameter`; the parameters to be
545559 updated with the stored parameters. If `None`, the parameters with which this
546560 `ExponentialMovingAverage` was initialized will be used.
547561 """
562+
548563 if self .temp_stored_params is None :
549564 raise RuntimeError ("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`" )
550565 if self .foreach :
@@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
560575
561576 def load_state_dict (self , state_dict : dict ) -> None :
562577 r"""
563- Args:
564578 Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
565579 ema state dict.
580+
581+ Args:
566582 state_dict (dict): EMA state. Should be an object returned
567583 from a call to :meth:`state_dict`.
568584 """
0 commit comments