
Commit 5c723f3

Moved is_fsdp_available to integrations.fsdp
1 parent 763fd3f commit 5c723f3

File tree: 6 files changed, +20 −19 lines

src/transformers/cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6

 from .configuration_utils import PretrainedConfig
+from .integrations import is_fsdp_enabled
 from .utils import (
-    is_fsdp_enabled,
     is_hqq_available,
     is_quanto_greater,
     is_torch_greater_or_equal,

src/transformers/integrations/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
     "eetq": ["replace_with_eetq_linear"],
     "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
     "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
-    "fsdp": ["is_fsdp_managed_module"],
+    "fsdp": ["is_fsdp_managed_module", "is_fsdp_enabled"],
     "ggml": [
         "GGUF_CONFIG_MAPPING",
         "GGUF_TOKENIZER_MAPPING",
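
With the "fsdp" entry extended, both helpers resolve through the package's lazy import machinery. A minimal import sketch, not part of the commit, assuming a transformers build that includes this change:

# Both helpers are now exposed from transformers.integrations via the
# lazy _import_structure entry shown in the diff above.
from transformers.integrations import is_fsdp_enabled, is_fsdp_managed_module

print(is_fsdp_enabled())  # False here: no process group and the FSDP env flags are unset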

src/transformers/integrations/fsdp.py

Lines changed: 16 additions & 1 deletion
@@ -13,9 +13,10 @@
 # limitations under the License.
 from __future__ import annotations

+import os
 from typing import TYPE_CHECKING

-from ..utils import is_torch_available
+from ..utils import is_torch_available, strtobool


 if TYPE_CHECKING:
@@ -36,3 +37,17 @@ def is_fsdp_managed_module(module: nn.Module) -> bool:
     return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
         module, "_is_fsdp_managed_module", False
     )
+
+
+def is_fsdp_enabled():
+    if is_torch_available():
+        import torch
+
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+            and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+            and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+        )
+
+    return False
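
The relocated helper only reports FSDP as enabled when a torch.distributed process group is initialized and both environment flags parse as truthy. A minimal sketch of that behavior, not part of the commit; the single-process gloo group and the explicitly set env vars are illustrative stand-ins for what an accelerate/torchrun FSDP launch would normally provide:

import os

import torch.distributed as dist

from transformers.integrations.fsdp import is_fsdp_enabled

# Flags that an accelerate-launched FSDP run would normally export.
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"

# Illustrative single-process group so torch.distributed.is_initialized() is True.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

print(is_fsdp_enabled())  # True: process group is up and both flags are truthy

dist.destroy_process_group()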

src/transformers/modeling_utils.py

Lines changed: 1 addition & 2 deletions
@@ -54,7 +54,7 @@
 from .distributed import DistributedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig
-from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
 from .integrations.accelerate import find_tied_parameters, init_empty_weights
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -110,7 +110,6 @@
     is_bitsandbytes_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
-    is_fsdp_enabled,
     is_kernels_available,
     is_offline_mode,
     is_optimum_available,

src/transformers/utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -161,7 +161,6 @@
     is_flute_available,
     is_fp_quant_available,
     is_fsdp_available,
-    is_fsdp_enabled,
     is_ftfy_available,
     is_g2p_en_available,
     is_galore_torch_available,

src/transformers/utils/import_utils.py

Lines changed: 1 addition & 13 deletions
@@ -35,7 +35,7 @@

 from packaging import version

-from . import logging, strtobool
+from . import logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -665,18 +665,6 @@ def is_torch_bf16_available():
     return is_torch_bf16_gpu_available()


-def is_fsdp_enabled():
-    if is_torch_available():
-        import torch
-
-        return (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-            and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
-            and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
-        )
-
-
 @lru_cache
 def is_torch_fp16_available_on_device(device):
     if not is_torch_available():
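
Both the removed definition and its new home in integrations/fsdp.py gate on strtobool(...) == 1. A quick sketch of what that comparison accepts, assuming transformers.utils.strtobool keeps its distutils-style behavior:

# strtobool maps common truthy strings to 1 and falsy strings to 0,
# which is why the FSDP env flags are compared against 1 rather than "true".
from transformers.utils import strtobool

assert strtobool("true") == 1
assert strtobool("1") == 1
assert strtobool("False") == 0
assert strtobool("no") == 0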
