Commit 763fd3f

Moved the is_fsdp_available function to import utils
1 parent 64b8242 · commit 763fd3f

File tree: 4 files changed, +16 −12 lines


src/transformers/cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
 
 from .configuration_utils import PretrainedConfig
-from .modeling_utils import is_fsdp_enabled
 from .utils import (
+    is_fsdp_enabled,
     is_hqq_available,
     is_quanto_greater,
     is_torch_greater_or_equal,

src/transformers/modeling_utils.py

Lines changed: 1 addition & 10 deletions
@@ -110,6 +110,7 @@
     is_bitsandbytes_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_fsdp_enabled,
     is_kernels_available,
     is_offline_mode,
     is_optimum_available,
@@ -124,7 +125,6 @@
     is_torch_xla_available,
     is_torch_xpu_available,
     logging,
-    strtobool,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
@@ -182,15 +182,6 @@
 from torch.distributed.tensor import DTensor
 
 
-def is_fsdp_enabled():
-    return (
-        torch.distributed.is_available()
-        and torch.distributed.is_initialized()
-        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
-        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
-    )
-
-
 def is_local_dist_rank_0():
     return (
         torch.distributed.is_available()
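
Since modeling_utils.py now re-imports the helper from .utils (first hunk above), the old import path should continue to resolve. A minimal sketch, assuming an environment where transformers imports cleanly; the aliases are only for illustration:

# Both paths should point at the same function object after this commit,
# because modeling_utils re-exports the helper it now imports from .utils.
from transformers.modeling_utils import is_fsdp_enabled as via_modeling_utils
from transformers.utils import is_fsdp_enabled as via_utils

assert via_modeling_utils is via_utils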

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@
     is_flute_available,
     is_fp_quant_available,
     is_fsdp_available,
+    is_fsdp_enabled,
     is_ftfy_available,
     is_g2p_en_available,
     is_galore_torch_available,

src/transformers/utils/import_utils.py

Lines changed: 13 additions & 1 deletion
@@ -35,7 +35,7 @@
 
 from packaging import version
 
-from . import logging
+from . import logging, strtobool
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -665,6 +665,18 @@ def is_torch_bf16_available():
     return is_torch_bf16_gpu_available()
 
 
+def is_fsdp_enabled():
+    if is_torch_available():
+        import torch
+
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+            and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+            and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+        )
+
+
 @lru_cache
 def is_torch_fp16_available_on_device(device):
     if not is_torch_available():
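
For reference, a minimal usage sketch of the relocated helper, assuming the re-export through transformers.utils shown in the utils/__init__.py hunk above. The environment variable names come from the diff; the helper only returns True when torch.distributed is also available and initialized (e.g. under an accelerate FSDP launch):

import os

# Both flags must be truthy, in addition to torch.distributed being
# available and initialized in the current process.
os.environ["ACCELERATE_USE_FSDP"] = "1"
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "1"

from transformers.utils import is_fsdp_enabled

# Outside an initialized distributed process group this prints False
# (or None if torch itself is unavailable, per the new implementation).
print(is_fsdp_enabled())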
