diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index d1ee538c..87efabe7 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -34,6 +34,8 @@ jobs: - "online-data-mixing" steps: + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache - uses: actions/checkout@v4 - name: Set up Python 3.11 uses: actions/setup-python@v4 diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 3269e1f2..1f14bcfc 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -735,6 +735,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic full_sd (`dict`): The full state dict to load, can only be on rank 0 """ # Third Party + # pylint: disable=import-outside-toplevel + from accelerate.utils.fsdp_utils import get_parameters_from_modules + # pylint: disable=import-outside-toplevel from torch.distributed.tensor import distribute_tensor import torch.distributed as dist @@ -847,7 +850,20 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: Returns: `torch.nn.Module`: Prepared model """ + # Standard + # pylint: disable=import-outside-toplevel + import copy + import warnings + # Third Party + # pylint: disable=import-outside-toplevel + from accelerate.utils.fsdp_utils import ( + fsdp2_prepare_auto_wrap_policy, + get_parameters_from_modules, + ) + from accelerate.utils.modeling import get_non_persistent_buffers + from accelerate.utils.other import get_module_children_bottom_up, is_compiled_module + # pylint: disable=import-outside-toplevel from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard diff --git a/plugins/framework/pyproject.toml b/plugins/framework/pyproject.toml index d8f570b1..0c43df11 100644 --- a/plugins/framework/pyproject.toml +++ b/plugins/framework/pyproject.toml @@ -25,12 +25,15 @@ dependencies = [ "numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3 "torch>2.2", "peft>=0.15.0", - "accelerate", + "accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d", "pandas", ] [tool.hatch.build.targets.wheel] only-include = ["src/fms_acceleration"] +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel.sources] "src" = ""