4 files changed: 16 additions, 12 deletions

File 1 of 4:
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6

 from .configuration_utils import PretrainedConfig
-from .modeling_utils import is_fsdp_enabled
 from .utils import (
+    is_fsdp_enabled,
     is_hqq_available,
     is_quanto_greater,
     is_torch_greater_or_equal,
File 2 of 4:

     is_bitsandbytes_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_fsdp_enabled,
     is_kernels_available,
     is_offline_mode,
     is_optimum_available,

...

     is_torch_xla_available,
     is_torch_xpu_available,
     logging,
-    strtobool,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files

...

     from torch.distributed.tensor import DTensor


-def is_fsdp_enabled():
-    return (
-        torch.distributed.is_available()
-        and torch.distributed.is_initialized()
-        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
-        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
-    )
-
-
 def is_local_dist_rank_0():
     return (
         torch.distributed.is_available()
File 3 of 4:

     is_flute_available,
     is_fp_quant_available,
     is_fsdp_available,
+    is_fsdp_enabled,
     is_ftfy_available,
     is_g2p_en_available,
     is_galore_torch_available,
File 4 of 4:

 
 from packaging import version
 
-from . import logging
+from . import logging, strtobool
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -665,6 +665,18 @@ def is_torch_bf16_available():
     return is_torch_bf16_gpu_available()
 
 
+def is_fsdp_enabled():
+    if is_torch_available():
+        import torch
+
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+            and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+            and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+        )
+
+
 @lru_cache
 def is_torch_fp16_available_on_device(device):
     if not is_torch_available():
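
For context, a minimal usage sketch of the relocated helper. This is an illustration based on the import changes above, not part of the diff itself; the exact public import path is an assumption inferred from the `from .utils import (is_fsdp_enabled, ...)` line in the first file.

    import os

    # Illustrative setup: both environment flags must parse as truthy for
    # is_fsdp_enabled() to return True, in addition to torch.distributed
    # being available and initialized.
    os.environ["ACCELERATE_USE_FSDP"] = "1"
    os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "1"

    # Assumed new import location (previously imported from modeling_utils).
    from transformers.utils import is_fsdp_enabled

    # Prints False here unless torch.distributed has actually been initialized.
    print(is_fsdp_enabled())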