Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
# pylint: disable=import-outside-toplevel
from torch.distributed.tensor import distribute_tensor
import torch.distributed as dist
from accelerate.utils.fsdp_utils import get_parameters_from_modules

# Model was previously copied to meta device
meta_sharded_sd = model.state_dict()
Expand Down Expand Up @@ -850,7 +851,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
# Third Party
# pylint: disable=import-outside-toplevel
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard

from accelerate.utils.fsdp_utils import get_parameters_from_modules, fsdp2_prepare_auto_wrap_policy
from accelerate.utils.other import is_compiled_module, get_module_children_bottom_up
from accelerate.utils.modeling import get_non_persistent_buffers
import copy
import warnings
is_type_fsdp = isinstance(model, FSDPModule) or (
# pylint: disable=undefined-variable
is_compiled_module(model)
Expand Down
Loading