diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 1f14bcfc..024fcc16 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -257,11 +257,11 @@ def patch_huggingface_fsdp2_load_full_state_dict(): from fms_acceleration.model_patcher import patch_target_module patch_target_module( - "accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict", - fsdp2_load_full_state_dict, + "accelerate.accelerator.fsdp2_prepare_model", fsdp2_prepare_model ) patch_target_module( - "accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model + "accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict", + fsdp2_load_full_state_dict, ) @@ -734,11 +734,8 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ - # Third Party - # pylint: disable=import-outside-toplevel - from accelerate.utils.fsdp_utils import get_parameters_from_modules - # pylint: disable=import-outside-toplevel + # Third Party from torch.distributed.tensor import distribute_tensor import torch.distributed as dist @@ -774,13 +771,6 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): tensor = tensor.contiguous() return tensor - ignored_params = { - p.detach() - # pylint: disable=undefined-variable - for p in get_parameters_from_modules( - accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device - ) - } if accelerator.is_main_process: for (param_name, full_param), sharded_param in zip( full_sd.items(), meta_sharded_sd.values() @@ -907,7 +897,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: if param.__class__.__name__ == "Params4bit": model_has_params4bit = True break - if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: # pylint: disable=undefined-variable non_persistent_buffer_fqns = get_non_persistent_buffers(