diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 3269e1f2..7330b1fc 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -738,6 +738,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
     # pylint: disable=import-outside-toplevel
     from torch.distributed.tensor import distribute_tensor
     import torch.distributed as dist
+    from accelerate.utils.fsdp_utils import get_parameters_from_modules
 
     # Model was previously copied to meta device
     meta_sharded_sd = model.state_dict()
@@ -850,7 +851,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     # Third Party
     # pylint: disable=import-outside-toplevel
     from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
-
+    from accelerate.utils.fsdp_utils import get_parameters_from_modules, fsdp2_prepare_auto_wrap_policy
+    from accelerate.utils.other import is_compiled_module, get_module_children_bottom_up
+    from accelerate.utils.modeling import get_non_persistent_buffers
+    import copy
+    import warnings
     is_type_fsdp = isinstance(model, FSDPModule) or (
         # pylint: disable=undefined-variable
         is_compiled_module(model)
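
For context, a minimal sketch of the loading pattern these imports support: fsdp2_load_full_state_dict re-shards a full (unsharded) state dict onto a fully_shard-wrapped model whose parameters are meta-device DTensors, using distribute_tensor to keep only each rank's shard. The helper name load_full_sd_into_fsdp2_model below is illustrative and not part of this PR; it only assumes torch.distributed is initialized and the model has already been wrapped with fully_shard.

# Illustrative sketch only; not the PR's implementation.
import torch
from torch.distributed.tensor import distribute_tensor


def load_full_sd_into_fsdp2_model(model: torch.nn.Module, full_sd: dict) -> None:
    """Shard each full tensor to match the model's DTensor layout, then load it."""
    # After fully_shard on a meta-device model, state_dict() holds meta DTensors
    # that carry the target device mesh and placements for every parameter.
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        meta_dtensor = meta_sharded_sd[name]
        # distribute_tensor slices the full tensor according to the DTensor's
        # mesh/placements, so this rank materializes only its local shard.
        sharded_sd[name] = distribute_tensor(
            full_tensor, meta_dtensor.device_mesh, meta_dtensor.placements
        )
    # assign=True replaces the meta tensors with the newly materialized shards.
    model.load_state_dict(sharded_sd, assign=True)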