
Commit 647c66a

Apply style fixes
1 parent 0052b21 commit 647c66a

File tree

1 file changed: +11 −13 lines changed


src/diffusers/training_utils.py

Lines changed: 11 additions & 13 deletions
@@ -6,15 +6,15 @@
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import CPUOffload, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from functools import partial
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
         kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
     else:
         # FSDP is enabled → use plugin's strategy, or default if None
-        kwargs["sharding_strategy"] = (
-            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
-        )
+        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
 
     return kwargs
 
 
 def wrap_with_fsdp(
-        model: torch.nn.Module,
-        device: Union[str, torch.device],
-        offload: bool = True,
-        use_orig_params: bool = True,
-        limit_all_gathers: bool = True,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
-        transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
 ) -> FSDP:
     """
     Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
@@ -459,7 +457,7 @@ def wrap_with_fsdp(
         "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
         "use_orig_params": use_orig_params,
         "limit_all_gathers": limit_all_gathers,
-        "auto_wrap_policy": auto_wrap_policy
+        "auto_wrap_policy": auto_wrap_policy,
     }
 
     if fsdp_kwargs:
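
For context, here is a minimal usage sketch of the two functions touched by this commit, wrap_with_fsdp and get_fsdp_kwargs_from_accelerator, assuming only the signatures visible in the diff. The tiny transformer model, the Accelerator setup, and the torchrun launch are illustrative assumptions, not part of the library.

# Hypothetical usage sketch; launch with e.g.:
#   torchrun --nproc_per_node=2 train_fsdp.py
import os

import torch
import torch.distributed as dist
from accelerate import Accelerator

from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp


def main():
    # torchrun sets the env vars that init_process_group and LOCAL_RANK rely on.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)

    # Any nn.Module works; a single encoder layer keeps the sketch self-contained.
    model = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4)

    # Pull FSDP-related kwargs from the Accelerate plugin; per this commit,
    # a missing sharding strategy falls back to FULL_SHARD.
    accelerator = Accelerator()
    fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)

    # Wrap with the defaults shown in the diff: CPU offload, original params,
    # limited all-gathers, and transformer auto-wrapping on the given class.
    fsdp_model = wrap_with_fsdp(
        model,
        device,
        offload=True,
        use_orig_params=True,
        limit_all_gathers=True,
        fsdp_kwargs=fsdp_kwargs,
        transformer_layer_cls={torch.nn.TransformerEncoderLayer},
    )

    print(fsdp_model)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Passing transformer_layer_cls feeds the transformer_auto_wrap_policy imported at the top of the file, so each listed layer class becomes its own FSDP unit instead of wrapping the whole model as one flat unit.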
