Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,11 @@ def patch_huggingface_fsdp2_load_full_state_dict():
from fms_acceleration.model_patcher import patch_target_module

patch_target_module(
"accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
"accelerate.accelerator.fsdp2_prepare_model", fsdp2_prepare_model
)
patch_target_module(
"accelerate.utils.fsdp_utils.fsdp2_prepare_model", fsdp2_prepare_model
"accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)


Expand Down Expand Up @@ -734,11 +734,8 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
or a VRAM spike can occur
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
# Third Party
# pylint: disable=import-outside-toplevel
from accelerate.utils.fsdp_utils import get_parameters_from_modules

# pylint: disable=import-outside-toplevel
# Third Party
from torch.distributed.tensor import distribute_tensor
import torch.distributed as dist

Expand Down Expand Up @@ -774,13 +771,6 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
tensor = tensor.contiguous()
return tensor

ignored_params = {
p.detach()
# pylint: disable=undefined-variable
for p in get_parameters_from_modules(
accelerator.state.fsdp_plugin.ignored_modules, model, accelerator.device
)
}
if accelerator.is_main_process:
for (param_name, full_param), sharded_param in zip(
full_sd.items(), meta_sharded_sd.values()
Expand Down Expand Up @@ -907,7 +897,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if param.__class__.__name__ == "Params4bit":
model_has_params4bit = True
break

if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# pylint: disable=undefined-variable
non_persistent_buffer_fqns = get_non_persistent_buffers(
Expand Down