Skip to content

Commit 429d464

Browse files
authored
Merge branch 'main' into fix-test-name-cogi2v
2 parents d5fc5c1 + fff4be8 commit 429d464

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

examples/community/hd_painter.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -898,13 +898,16 @@ class GaussianSmoothing(nn.Module):
898898
Apply gaussian smoothing on a
899899
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
900900
in the input using a depthwise convolution.
901-
Arguments:
902-
channels (int, sequence): Number of channels of the input tensors. Output will
903-
have this number of channels as well.
904-
kernel_size (int, sequence): Size of the gaussian kernel.
905-
sigma (float, sequence): Standard deviation of the gaussian kernel.
906-
dim (int, optional): The number of dimensions of the data.
907-
Default value is 2 (spatial).
901+
902+
Args:
903+
channels (`int` or `sequence`):
904+
Number of channels of the input tensors. The output will have this number of channels as well.
905+
kernel_size (`int` or `sequence`):
906+
Size of the Gaussian kernel.
907+
sigma (`float` or `sequence`):
908+
Standard deviation of the Gaussian kernel.
909+
dim (`int`, *optional*, defaults to `2`):
910+
The number of dimensions of the data. Default is 2 (spatial dimensions).
908911
"""
909912

910913
def __init__(self, channels, kernel_size, sigma, dim=2):
@@ -944,10 +947,14 @@ def __init__(self, channels, kernel_size, sigma, dim=2):
944947
def forward(self, input):
945948
"""
946949
Apply gaussian filter to input.
947-
Arguments:
948-
input (torch.Tensor): Input to apply gaussian filter on.
950+
951+
Args:
952+
input (`torch.Tensor` of shape `(N, C, H, W)`):
953+
Input to apply Gaussian filter on.
954+
949955
Returns:
950-
filtered (torch.Tensor): Filtered output.
956+
`torch.Tensor`:
957+
The filtered output tensor with the same shape as the input.
951958
"""
952959
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
953960

src/diffusers/training_utils.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636

3737
def set_seed(seed: int):
3838
"""
39-
Args:
4039
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
40+
41+
Args:
4142
seed (`int`): The seed to set.
4243
"""
4344
random.seed(seed)
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
194195

195196

196197
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
198+
"""
199+
Casts the training parameters of the model to the specified data type.
200+
201+
Args:
202+
model: The PyTorch model whose parameters will be cast.
203+
dtype: The data type to which the model parameters will be cast.
204+
"""
197205
if not isinstance(model, list):
198206
model = [model]
199207
for m in model:
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
225233
def compute_density_for_timestep_sampling(
226234
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
227235
):
228-
"""Compute the density for sampling the timesteps when doing SD3 training.
236+
"""
237+
Compute the density for sampling the timesteps when doing SD3 training.
229238
230239
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
231240
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
244253

245254

246255
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
247-
"""Computes loss weighting scheme for SD3 training.
256+
"""
257+
Computes loss weighting scheme for SD3 training.
248258
249259
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
250260
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
261271

262272

263273
def free_memory():
264-
"""Runs garbage collection. Then clears the cache of the available accelerator."""
274+
"""
275+
Runs garbage collection. Then clears the cache of the available accelerator.
276+
"""
265277
gc.collect()
266278

267279
if torch.cuda.is_available():
@@ -494,7 +506,8 @@ def pin_memory(self) -> None:
494506
self.shadow_params = [p.pin_memory() for p in self.shadow_params]
495507

496508
def to(self, device=None, dtype=None, non_blocking=False) -> None:
497-
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
509+
r"""
510+
Move internal buffers of the ExponentialMovingAverage to `device`.
498511
499512
Args:
500513
device: like `device` argument to `torch.Tensor.to`
@@ -528,23 +541,25 @@ def state_dict(self) -> dict:
528541

529542
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
530543
r"""
544+
Saves the current parameters for restoring later.
545+
531546
Args:
532-
Save the current parameters for restoring later.
533-
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
534-
temporarily stored.
547+
parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
535548
"""
536549
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
537550

538551
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
539552
r"""
540-
Args:
541-
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
542-
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
553+
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
554+
without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
543555
validation (or model saving), use this to restore the former parameters.
556+
557+
Args:
544558
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
545559
updated with the stored parameters. If `None`, the parameters with which this
546560
`ExponentialMovingAverage` was initialized will be used.
547561
"""
562+
548563
if self.temp_stored_params is None:
549564
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
550565
if self.foreach:
@@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
560575

561576
def load_state_dict(self, state_dict: dict) -> None:
562577
r"""
563-
Args:
564578
Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
565579
ema state dict.
580+
581+
Args:
566582
state_dict (dict): EMA state. Should be an object returned
567583
from a call to :meth:`state_dict`.
568584
"""

0 commit comments

Comments
 (0)