Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
- "online-data-mixing"

steps:
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
# Third Party
# pylint: disable=import-outside-toplevel
from accelerate.utils.fsdp_utils import get_parameters_from_modules

# pylint: disable=import-outside-toplevel
from torch.distributed.tensor import distribute_tensor
import torch.distributed as dist
Expand Down Expand Up @@ -847,7 +850,20 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
Returns:
`torch.nn.Module`: Prepared model
"""
# Standard
# pylint: disable=import-outside-toplevel
import copy
import warnings

# Third Party
# pylint: disable=import-outside-toplevel
from accelerate.utils.fsdp_utils import (
fsdp2_prepare_auto_wrap_policy,
get_parameters_from_modules,
)
from accelerate.utils.modeling import get_non_persistent_buffers
from accelerate.utils.other import get_module_children_bottom_up, is_compiled_module

# pylint: disable=import-outside-toplevel
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard

Expand Down
5 changes: 4 additions & 1 deletion plugins/framework/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ dependencies = [
"numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
"torch>2.2",
"peft>=0.15.0",
"accelerate",
"accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d",
"pandas",
]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration"]

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel.sources]
"src" = ""