@@ -6,15 +6,15 @@
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed.fsdp import CPUOffload, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from functools import partial

 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
         kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
     else:
         # FSDP is enabled → use plugin's strategy, or default if None
-        kwargs["sharding_strategy"] = (
-            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
-        )
+        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD

     return kwargs


 def wrap_with_fsdp(
-    model: torch.nn.Module,
-    device: Union[str, torch.device],
-    offload: bool = True,
-    use_orig_params: bool = True,
-    limit_all_gathers: bool = True,
-    fsdp_kwargs: Optional[Dict[str, Any]] = None,
-    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
 ) -> FSDP:
     """
     Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
@@ -459,7 +457,7 @@ def wrap_with_fsdp(
         "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
         "use_orig_params": use_orig_params,
         "limit_all_gathers": limit_all_gathers,
-        "auto_wrap_policy": auto_wrap_policy
+        "auto_wrap_policy": auto_wrap_policy,
     }

     if fsdp_kwargs: