Commit e4893b9

awaelchli authored and lantiga committed
Refined FSDP saving logic and error messaging when path exists (#18884)
(cherry picked from commit e66be67)
1 parent c148282 commit e4893b9

File tree

6 files changed: +126 −15 lines

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))
 
 
+- Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884))
+
+
 ## [2.1.0] - 2023-10-11
 
 ### Added

src/lightning/fabric/strategies/fsdp.py

Lines changed: 9 additions & 3 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
+import shutil
 from contextlib import ExitStack
 from datetime import timedelta
 from functools import partial
@@ -432,8 +432,8 @@ def save_checkpoint(
 
         # broadcast the path from rank 0 to ensure all the states are saved in a common path
         path = Path(self.broadcast(path))
-        if path.is_dir() and os.listdir(path):
-            raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
+        if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
+            raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
 
         from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -454,7 +454,10 @@ def save_checkpoint(
         module = modules[0]
 
         if self._state_dict_type == "sharded":
+            if path.is_file():
+                path.unlink()
             path.mkdir(parents=True, exist_ok=True)
+
             state_dict_ctx = _get_sharded_state_dict_context(module)
 
             # replace the modules and optimizer objects in the state with their local state dict
@@ -483,6 +486,9 @@ def save_checkpoint(
             torch.save(metadata, path / _METADATA_FILENAME)
 
         elif self._state_dict_type == "full":
+            if _is_sharded_checkpoint(path):
+                shutil.rmtree(path)
+
             state_dict_ctx = _get_full_state_dict_context(module, world_size=self.world_size)
             full_state: Dict[str, Any] = {}
             with state_dict_ctx:
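
The new checks hinge on the existing `_is_sharded_checkpoint` helper, which the tests below also import. A minimal sketch of the idea, assuming the sharded format is recognized by the `meta.pt` metadata file written next to the shards (the constant `_METADATA_FILENAME` appears in the diff above; the exact body here is illustrative, not the library's implementation):

```python
from pathlib import Path

# Assumed filename: the tests below create "meta.pt" to mark a directory as a sharded checkpoint.
_METADATA_FILENAME = "meta.pt"


def _is_sharded_checkpoint(path: Path) -> bool:
    """Sketch: a sharded FSDP checkpoint is a directory containing the metadata file."""
    return path.is_dir() and (path / _METADATA_FILENAME).is_file()
```

With such a predicate, a "full" save only raises when the target is an existing directory that the strategy did not itself create as a sharded checkpoint; files and previous sharded checkpoints are overwritten.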

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue saving the `last.ckpt` file when using `ModelCheckpoint` on a remote filesystem and no logger is used ([#18867](https://github.com/Lightning-AI/lightning/issues/18867))
 
 
+- Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884))
+
+
 ## [2.1.0] - 2023-10-11
 
 ### Added

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 7 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
+import shutil
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
@@ -522,12 +522,14 @@ def save_checkpoint(
         )
 
         path = Path(self.broadcast(filepath))
-        if path.is_dir() and os.listdir(path):
-            raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
+        if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
+            raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
 
         if self._state_dict_type == "sharded":
             from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
 
+            if path.is_file():
+                path.unlink()
             path.mkdir(parents=True, exist_ok=True)
 
             converted_state = {"model": checkpoint.pop("state_dict")}
@@ -542,6 +544,8 @@ def save_checkpoint(
             if self.global_rank == 0:
                 torch.save(checkpoint, path / _METADATA_FILENAME)
         elif self._state_dict_type == "full":
+            if _is_sharded_checkpoint(path):
+                shutil.rmtree(path)
             return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
         else:
             raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 51 additions & 4 deletions
@@ -27,6 +27,7 @@
     _FSDPBackwardSyncControl,
     _get_full_state_dict_context,
     _has_meta_device_parameters,
+    _is_sharded_checkpoint,
 )
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
@@ -275,15 +276,61 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):
 
 
 @RunIf(min_torch="2.0.0")
+@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-def test_fsdp_save_checkpoint_folder_exists(tmp_path):
-    path = tmp_path / "exists"
+@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
+@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
+@mock.patch("lightning.fabric.strategies.fsdp.torch.save", return_value=Mock())
+@mock.patch("lightning.fabric.strategies.fsdp.shutil", return_value=MagicMock())
+def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
+    strategy = FSDPStrategy(state_dict_type="full")
+
+    # state_dict_type='full', path exists, path is not a sharded checkpoint: error
+    path = tmp_path / "not-empty"
     path.mkdir()
     (path / "file").touch()
-    strategy = FSDPStrategy()
-    with pytest.raises(FileExistsError, match="exists and is not empty"):
+    assert not _is_sharded_checkpoint(path)
+    with pytest.raises(IsADirectoryError, match="exists and is a directory"):
         strategy.save_checkpoint(path=path, state=Mock())
 
+    # state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
+    path = tmp_path / "sharded-checkpoint"
+    path.mkdir()
+    (path / "meta.pt").touch()
+    assert _is_sharded_checkpoint(path)
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path=path, state={"model": model})
+    shutil_mock.rmtree.assert_called_once_with(path)
+
+    # state_dict_type='full', path exists, path is a file: no error (overwrite)
+    path = tmp_path / "file.pt"
+    path.touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    torch_save_mock.reset_mock()
+    strategy.save_checkpoint(path=path, state={"model": model})
+    torch_save_mock.assert_called_once()
+
+    strategy = FSDPStrategy(state_dict_type="sharded")
+
+    # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
+    path = tmp_path / "not-empty-2"
+    path.mkdir()
+    (path / "file").touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path=path, state={"model": model})
+    assert (path / "file").exists()
+
+    # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
+    path = tmp_path / "file-2.pt"
+    path.touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path=path, state={"model": model})
+    assert path.is_dir()
+
 
 @RunIf(min_torch="2.0.0")
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 53 additions & 5 deletions
@@ -13,6 +13,7 @@
 import torch
 import torch.nn as nn
 from lightning.fabric.plugins.environments import LightningEnvironment
+from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
@@ -760,14 +761,61 @@ def test_save_checkpoint_storage_options(tmp_path):
     strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock())
 
 
+@RunIf(min_torch="2.0.0")
+@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
 @mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-def test_save_checkpoint_folder_exists(tmp_path):
-    path = tmp_path / "exists"
+@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
+@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
+@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock())
+@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock())
+def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
+    strategy = FSDPStrategy(state_dict_type="full")
+
+    # state_dict_type='full', path exists, path is not a sharded checkpoint: error
+    path = tmp_path / "not-empty"
     path.mkdir()
     (path / "file").touch()
-    strategy = FSDPStrategy()
-    with pytest.raises(FileExistsError, match="exists and is not empty"):
-        strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock())
+    assert not _is_sharded_checkpoint(path)
+    with pytest.raises(IsADirectoryError, match="exists and is a directory"):
+        strategy.save_checkpoint(Mock(), filepath=path)
+
+    # state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
+    path = tmp_path / "sharded-checkpoint"
+    path.mkdir()
+    (path / "meta.pt").touch()
+    assert _is_sharded_checkpoint(path)
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(Mock(), filepath=path)
+    shutil_mock.rmtree.assert_called_once_with(path)
+
+    # state_dict_type='full', path exists, path is a file: no error (overwrite)
+    path = tmp_path / "file.pt"
+    path.touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    torch_save_mock.reset_mock()
+    strategy.save_checkpoint(Mock(), filepath=path)
+    torch_save_mock.assert_called_once()
+
+    strategy = FSDPStrategy(state_dict_type="sharded")
+
+    # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
+    path = tmp_path / "not-empty-2"
+    path.mkdir()
+    (path / "file").touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+    assert (path / "file").exists()
+
+    # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
+    path = tmp_path / "file-2.pt"
+    path.touch()
+    model = Mock(spec=FullyShardedDataParallel)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+    assert path.is_dir()
 
 
 @mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
