3636
3737def  set_seed (seed : int ):
3838    """ 
39-     Args: 
4039    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. 
40+ 
41+     Args: 
4142        seed (`int`): The seed to set. 
4243    """ 
4344    random .seed (seed )
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
194195
195196
196197def  cast_training_params (model : Union [torch .nn .Module , List [torch .nn .Module ]], dtype = torch .float32 ):
198+     """ 
199+     Casts the training parameters of the model to the specified data type. 
200+ 
201+     Args: 
202+         model: The PyTorch model whose parameters will be cast. 
203+         dtype: The data type to which the model parameters will be cast. 
204+     """ 
197205    if  not  isinstance (model , list ):
198206        model  =  [model ]
199207    for  m  in  model :
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
225233def  compute_density_for_timestep_sampling (
226234    weighting_scheme : str , batch_size : int , logit_mean : float  =  None , logit_std : float  =  None , mode_scale : float  =  None 
227235):
228-     """Compute the density for sampling the timesteps when doing SD3 training. 
236+     """ 
237+     Compute the density for sampling the timesteps when doing SD3 training. 
229238
230239    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. 
231240
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
244253
245254
246255def  compute_loss_weighting_for_sd3 (weighting_scheme : str , sigmas = None ):
247-     """Computes loss weighting scheme for SD3 training. 
256+     """ 
257+     Computes loss weighting scheme for SD3 training. 
248258
249259    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. 
250260
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
261271
262272
263273def  free_memory ():
264-     """Runs garbage collection. Then clears the cache of the available accelerator.""" 
274+     """ 
275+     Runs garbage collection. Then clears the cache of the available accelerator. 
276+     """ 
265277    gc .collect ()
266278
267279    if  torch .cuda .is_available ():
@@ -494,7 +506,8 @@ def pin_memory(self) -> None:
494506        self .shadow_params  =  [p .pin_memory () for  p  in  self .shadow_params ]
495507
496508    def  to (self , device = None , dtype = None , non_blocking = False ) ->  None :
497-         r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
509+         r""" 
510+         Move internal buffers of the ExponentialMovingAverage to `device`. 
498511
499512        Args: 
500513            device: like `device` argument to `torch.Tensor.to` 
@@ -528,23 +541,25 @@ def state_dict(self) -> dict:
528541
529542    def  store (self , parameters : Iterable [torch .nn .Parameter ]) ->  None :
530543        r""" 
544+         Saves the current parameters for restoring later. 
545+ 
531546        Args: 
532-         Save the current parameters for restoring later. 
533-             parameters: Iterable of `torch.nn.Parameter`; the parameters to be 
534-                 temporarily stored. 
547+             parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored. 
535548        """ 
536549        self .temp_stored_params  =  [param .detach ().cpu ().clone () for  param  in  parameters ]
537550
538551    def  restore (self , parameters : Iterable [torch .nn .Parameter ]) ->  None :
539552        r""" 
540-         Args: 
541-         Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: 
542-         affecting the original optimization process. Store the parameters before the `copy_to()` method. After 
553+         Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters 
554+         without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After 
543555        validation (or model saving), use this to restore the former parameters. 
556+ 
557+         Args: 
544558            parameters: Iterable of `torch.nn.Parameter`; the parameters to be 
545559                updated with the stored parameters. If `None`, the parameters with which this 
546560                `ExponentialMovingAverage` was initialized will be used. 
547561        """ 
562+ 
548563        if  self .temp_stored_params  is  None :
549564            raise  RuntimeError ("This ExponentialMovingAverage has no `store()`ed weights "  "to `restore()`" )
550565        if  self .foreach :
@@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
560575
561576    def  load_state_dict (self , state_dict : dict ) ->  None :
562577        r""" 
563-         Args: 
564578        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the 
565579        ema state dict. 
580+ 
581+         Args: 
566582            state_dict (dict): EMA state. Should be an object returned 
567583                from a call to :meth:`state_dict`. 
568584        """ 
0 commit comments