
Commit d6c986c

Bunch of FSDP improvements (#3671)
* Feat: split tests
* Feat: finito
* Fix
* Final, tests pass
1 parent 1ac8643 commit d6c986c

File tree

5 files changed: +61 -52 lines changed

* src/accelerate/accelerator.py
* src/accelerate/state.py
* src/accelerate/test_utils/testing.py
* src/accelerate/utils/fsdp_utils.py
* tests/fsdp/test_fsdp.py


src/accelerate/accelerator.py

Lines changed: 17 additions & 1 deletion
@@ -125,6 +125,7 @@
     FSDP2_PYTORCH_VERSION,
     FSDP_PYTORCH_VERSION,
     PROFILE_PATTERN_NAME,
+    SCALER_NAME,
 )
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module
@@ -3521,6 +3522,21 @@ def _inner(folder):
             else:
                 models.append(model)

+        # We need to load the scaler state before the optimizer for FSDP2:
+        # `torch.distributed.checkpoint.set_optimizer_state_dict`, which we use to set the state of the optimizer, calls `optimizer.step` on
+        # a dummy tensor, but since the scaler is not initialized, it will raise an error (the scaler exists but its `_scale` is None)
+        scaler = None
+        if self.scaler is not None and self.is_fsdp2:
+            input_scaler_file = os.path.join(input_dir, SCALER_NAME)
+            scaler_state = torch.load(input_scaler_file)
+            self.scaler.load_state_dict(scaler_state)
+            # We also need to call `_lazy_init_scale_growth_tracker` to initialize the scaler, as it would otherwise
+            # only be called on the first call to `scale`
+            self.scaler._lazy_init_scale_growth_tracker(self.scaler._device)
+            logger.info("GradScaler state loaded successfully")
+        else:
+            scaler = self.scaler
+
         # Load the optimizers taking care of FSDP and DeepSpeed nuances
         optimizers = []
         if self.distributed_type == DistributedType.FSDP:
@@ -3569,7 +3585,7 @@ def _inner(folder):
             schedulers,
             dataloaders,
             self.state.process_index,
-            self.scaler,
+            scaler,
             map_location,
             load_kwargs,
             **load_model_func_kwargs,

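The ordering matters because `set_optimizer_state_dict` initializes optimizer state by running a dummy `optimizer.step()`; when that step is routed through a GradScaler (as it is with Accelerate's wrapped optimizer), the scaler must already be restored and lazily initialized. A rough sketch of that restore order, assuming an initialized process group; the function name and checkpoint filenames are placeholders, not Accelerate's API:

```python
import os

import torch
from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict


def restore_scaler_then_optimizer(scaler, model, optimizer, input_dir):
    # Restore the GradScaler first so its internal `_scale` tensor exists before
    # any optimizer step is triggered.
    scaler_state = torch.load(os.path.join(input_dir, "scaler.pt"))  # placeholder filename
    scaler.load_state_dict(scaler_state)
    # `load_state_dict` alone does not materialize `_scale`; that normally happens
    # lazily on the first `scaler.scale(...)` call, which would be too late here.
    scaler._lazy_init_scale_growth_tracker(scaler._device)

    # Only now hand the optimizer state to the distributed checkpointing API,
    # which internally performs a dummy step to allocate optimizer state.
    optim_state = torch.load(os.path.join(input_dir, "optimizer.bin"))  # placeholder filename
    set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_state)
```
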
src/accelerate/state.py

Lines changed: 9 additions & 6 deletions
@@ -213,12 +213,6 @@ def __init__(self, cpu: bool = False, **kwargs):
             if self.backend == "tccl":
                 local_rank = os.environ.get("LOCAL_RANK", -1)
                 torch.sdaa.set_device(f"sdaa:{local_rank}")
-            if (
-                self.backend == "nccl"
-                and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-                and os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
-            ):
-                self.backend = "cuda:nccl,cpu:gloo"
             dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
             # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
             use_deepspeed = True
@@ -230,6 +224,15 @@ def __init__(self, cpu: bool = False, **kwargs):
             if self.backend == "tccl":
                 local_rank = os.environ.get("LOCAL_RANK", -1)
                 torch.sdaa.set_device(f"sdaa:{local_rank}")
+            if (
+                self.backend == "nccl"
+                and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
+                and (
+                    os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
+                    or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
+                )
+            ):
+                self.backend = "cuda:nccl,cpu:gloo"
             torch.distributed.init_process_group(backend=self.backend, **kwargs)

         # XPU and CPU require special env configs to be set

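The override is removed from the DeepSpeed branch and re-added on the plain `torch.distributed.init_process_group` path, and it now also fires for FULL_STATE_DICT. The env-var logic can be read as a small standalone helper; this is a sketch, not Accelerate's internal code, and the rationale (parameter offloading and FULL_STATE_DICT gathering both produce CPU tensors, which NCCL alone cannot handle, so gloo is registered for CPU alongside NCCL for CUDA) is an assumption drawn from the condition itself:

```python
import os


def pick_process_group_backend(requested: str = "nccl") -> str:
    """Sketch of the backend override: use a per-device backend map when FSDP needs CPU collectives."""
    fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
    offload_params = os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
    full_state_dict = os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
    if requested == "nccl" and fsdp_enabled and (offload_params or full_state_dict):
        # gloo handles tensors living on CPU, nccl keeps handling CUDA tensors
        return "cuda:nccl,cpu:gloo"
    return requested


# e.g. torch.distributed.init_process_group(backend=pick_process_group_backend(), ...)
```
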
src/accelerate/test_utils/testing.py

Lines changed: 4 additions & 0 deletions
@@ -250,6 +250,10 @@ def require_fp8(test_case):
     return unittest.skipUnless(fp8_is_available, "test requires FP8 support")(test_case)


+def require_fsdp2(test_case):
+    return unittest.skipUnless(is_torch_version(">=", "2.5.0"), "test requires FSDP2 (torch >= 2.5.0)")(test_case)
+
+
 def require_mlu(test_case):
     """
     Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available.

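The new decorator follows the same pattern as the other `require_*` helpers in this file. A hypothetical test using it could look like this (the test class and method names are invented for illustration):

```python
import unittest

from accelerate.test_utils.testing import require_fsdp2


class ExampleFSDP2Tests(unittest.TestCase):
    @require_fsdp2
    def test_fsdp2_only_behavior(self):
        # Runs only when torch >= 2.5.0 is installed; otherwise the test is skipped.
        self.assertTrue(True)
```
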
src/accelerate/utils/fsdp_utils.py

Lines changed: 7 additions & 3 deletions
@@ -179,10 +179,9 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
         else nullcontext()
     )
     sd_options = _prepare_sd_options(fsdp_plugin)
-
     with ctx:
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
-            if type(model) is not FSDP and accelerator.process_index != 0:
+            if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
                 if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
                     raise ValueError(
                         "Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
@@ -192,7 +191,12 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
             weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
             input_model_file = os.path.join(input_dir, weights_name)
             logger.info(f"Loading model from {input_model_file}")
-            state_dict = torch.load(input_model_file, weights_only=True)
+            # we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
+            load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
+            if load_model:
+                state_dict = torch.load(input_model_file, weights_only=True)
+            else:
+                state_dict = {}
             logger.info(f"Model loaded from {input_model_file}")
         elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
             weights_name = (

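The `load_model` gate means only the main process reads the full checkpoint from disk, while the other ranks pass an empty dict and receive their shards via a broadcast from rank 0. A rough, self-contained sketch of that pattern using torch's distributed checkpoint API, assuming a torch version that supports `broadcast_from_rank0`; this is not the exact call Accelerate makes, and the function name is a placeholder:

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_state_dict_rank0_broadcast(model, input_model_file):
    # Only rank 0 materializes the full state dict in memory.
    if dist.get_rank() == 0:
        state_dict = torch.load(input_model_file, weights_only=True)
    else:
        state_dict = {}
    # Every rank calls this collectively; rank 0's tensors are broadcast and
    # distributed onto the other ranks' (sharded) parameters.
    set_model_state_dict(
        model,
        state_dict,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
```
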
tests/fsdp/test_fsdp.py

Lines changed: 24 additions & 42 deletions
@@ -29,6 +29,7 @@
     get_launch_command,
     path_in_accelerate_package,
     require_fp16,
+    require_fsdp2,
     require_multi_device,
     require_non_cpu,
     require_non_torch_xla,
@@ -37,7 +38,6 @@
 )
 from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
 from accelerate.utils.constants import (
-    FSDP2_PYTORCH_VERSION,
     FSDP2_STATE_DICT_TYPE,
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
@@ -46,7 +46,6 @@
 )
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
 from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
-from accelerate.utils.versions import is_torch_version


 set_seed(42)
@@ -63,10 +62,6 @@
 if is_bf16_available():
     dtypes.append(BF16)

-FSDP_VERSIONS = [1]
-if is_torch_version(">=", FSDP2_PYTORCH_VERSION):
-    FSDP_VERSIONS.append(2)
-

 @require_non_cpu
 @require_non_torch_xla
@@ -90,24 +85,7 @@ def setUp(self):
             2: self.fsdp2_env,
         }

-    def run(self, result=None):
-        """Override run to get the current test name and format failures to include FSDP version."""
-        test_method = getattr(self, self._testMethodName)
-        orig_test_method = test_method
-
-        def test_wrapper(*args, **kwargs):
-            for fsdp_version in FSDP_VERSIONS:
-                try:
-                    self.current_fsdp_version = fsdp_version
-                    return orig_test_method(*args, **kwargs)
-                except Exception as e:
-                    raise type(e)(f"FSDP version {fsdp_version}: {str(e)}") from e
-
-        setattr(self, self._testMethodName, test_wrapper)
-        try:
-            return super().run(result)
-        finally:
-            setattr(self, self._testMethodName, orig_test_method)
+        self.current_fsdp_version = 1

     def test_sharding_strategy(self):
         from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
@@ -421,6 +399,15 @@ def test_cpu_ram_efficient_loading(self):
         assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"


+@require_fsdp2
+@require_non_cpu
+@require_non_torch_xla
+class FSDP2PluginIntegration(FSDPPluginIntegration):
+    def setUp(self):
+        super().setUp()
+        self.current_fsdp_version = 2
+
+
 @run_first
 # Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
 @require_non_torch_xla
@@ -462,24 +449,7 @@ def setUp(self):
         self.n_train = 160
         self.n_val = 160

-    def run(self, result=None):
-        """Override run to get the current test name and format failures to include FSDP version."""
-        test_method = getattr(self, self._testMethodName)
-        orig_test_method = test_method
-
-        def test_wrapper(*args, **kwargs):
-            for fsdp_version in FSDP_VERSIONS:
-                try:
-                    self.current_fsdp_version = fsdp_version
-                    return orig_test_method(*args, **kwargs)
-                except Exception as e:
-                    raise type(e)(f"FSDP version {fsdp_version}: {str(e)}") from e
-
-        setattr(self, self._testMethodName, test_wrapper)
-        try:
-            return super().run(result)
-        finally:
-            setattr(self, self._testMethodName, orig_test_method)
+        self.current_fsdp_version = 1

     @require_fp16
     def test_performance(self):
@@ -633,3 +603,15 @@ def test_peak_memory_usage(self):
         )
         with patch_environment(omp_num_threads=1):
             execute_subprocess_async(cmd_config)
+
+
+@require_fsdp2
+@run_first
+# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
+@require_non_torch_xla
+@require_multi_device
+@slow
+class FSDP2IntegrationTest(FSDPIntegrationTest):
+    def setUp(self):
+        super().setUp()
+        self.current_fsdp_version = 2

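The test refactor replaces the `run()` override that looped over FSDP versions with plain subclassing: the FSDP2 variants inherit every test and only change `current_fsdp_version` in `setUp`, so failures are attributed to a distinct class name instead of a rewritten exception message. A stripped-down illustration of the pattern (class and method names are invented for the example, not taken from the real suite):

```python
import unittest


class BaseFSDPTests(unittest.TestCase):
    def setUp(self):
        self.current_fsdp_version = 1

    def test_version_is_configured(self):
        self.assertIn(self.current_fsdp_version, (1, 2))


class FSDP2Tests(BaseFSDPTests):
    def setUp(self):
        super().setUp()
        # Every inherited test now runs against FSDP2 and reports under this class name.
        self.current_fsdp_version = 2
```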