Skip to content

Commit 414c863

Browse files
awaelchlilantiga
andauthored
(9/n) Support 2D Parallelism - Remaining Checkpoint Logic (#19888)
Co-authored-by: Luca Antiga <[email protected]>
1 parent fa1126e commit 414c863

File tree

6 files changed

+334
-61
lines changed

6 files changed

+334
-61
lines changed

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def _load_checkpoint(
412412
path: Path,
413413
state: Dict[str, Union[Module, Optimizer, Any]],
414414
strict: bool = True,
415+
optimizer_states_from_list: bool = False,
415416
) -> Dict[str, Any]:
416417
from torch.distributed.checkpoint.state_dict import (
417418
StateDictOptions,
@@ -473,8 +474,15 @@ def _load_checkpoint(
473474
full_state_dict=True,
474475
strict=strict,
475476
)
476-
for optimizer_name, optimizer in optimizers.items():
477-
optimizer_state = _rekey_optimizer_state_if_needed(checkpoint.pop(optimizer_name), module)
477+
for optimizer_idx, (optimizer_name, optimizer) in enumerate(optimizers.items()):
478+
if optimizer_states_from_list:
479+
# This code path is only used by `lightning.pytorch`, which saves optimizer states as a list
480+
# rather than individual states at the top level.
481+
optimizer_state = checkpoint["optimizer_states"][optimizer_idx]
482+
else:
483+
optimizer_state = checkpoint.pop(optimizer_name)
484+
485+
optimizer_state = _rekey_optimizer_state_if_needed(optimizer_state, module)
478486
set_optimizer_state_dict(
479487
module,
480488
optimizer,

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818

1919
- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
2020

21+
- Added `ModelParallelStrategy` to support 2D parallelism ([#19878](https://github.com/Lightning-AI/pytorch-lightning/pull/19878), [#19888](https://github.com/Lightning-AI/pytorch-lightning/pull/19888))
22+
23+
2124

2225
### Changed
2326

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import shutil
1415
from contextlib import contextmanager, nullcontext
1516
from datetime import timedelta
17+
from pathlib import Path
1618
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union
1719

1820
import torch
@@ -22,9 +24,13 @@
2224
from typing_extensions import override
2325

2426
import lightning.pytorch as pl
25-
from lightning.fabric.plugins import CheckpointIO
2627
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
27-
from lightning.fabric.strategies.model_parallel import _setup_device_mesh
28+
from lightning.fabric.strategies.model_parallel import (
29+
_distributed_checkpoint_save,
30+
_is_sharded_checkpoint,
31+
_load_checkpoint,
32+
_setup_device_mesh,
33+
)
2834
from lightning.fabric.utilities.distributed import (
2935
_distributed_is_initialized,
3036
_get_default_process_group_backend_for_device,
@@ -34,6 +40,7 @@
3440
from lightning.fabric.utilities.distributed import group as _group
3541
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
3642
from lightning.fabric.utilities.init import _materialize_distributed_module
43+
from lightning.fabric.utilities.load import _METADATA_FILENAME
3744
from lightning.fabric.utilities.optimizer import _optimizers_to_device
3845
from lightning.fabric.utilities.seed import reset_seed
3946
from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -95,16 +102,6 @@ def device_mesh(self) -> "DeviceMesh":
95102
raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.")
96103
return self._device_mesh
97104

98-
@property
99-
@override
100-
def checkpoint_io(self) -> CheckpointIO:
101-
raise NotImplementedError(f"The `{type(self).__name__}` does not use the `CheckpointIO` plugin interface.")
102-
103-
@checkpoint_io.setter
104-
@override
105-
def checkpoint_io(self, io: CheckpointIO) -> None:
106-
raise NotImplementedError(f"The `{type(self).__name__}` does not support setting a `CheckpointIO` plugin.")
107-
108105
@property
109106
@override
110107
def root_device(self) -> torch.device:
@@ -253,6 +250,11 @@ def teardown(self) -> None:
253250

254251
@override
255252
def lightning_module_state_dict(self) -> Dict[str, Any]:
253+
"""Collects the state dict of the model.
254+
255+
Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
256+
257+
"""
256258
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
257259

258260
state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True)
@@ -266,6 +268,11 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
266268

267269
@override
268270
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
271+
"""Collects the state of the given optimizer.
272+
273+
Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
274+
275+
"""
269276
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
270277
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
271278
from torch.distributed.fsdp import OptimStateKeyType
@@ -275,8 +282,9 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
275282
optimizer = optimizer._optimizer
276283

277284
assert self.model is not None
285+
278286
state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
279-
if not self._save_distributed_checkpoint:
287+
if not self._save_distributed_checkpoint and self.global_rank == 0:
280288
# Store the optimizer state dict in standard format
281289
state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
282290
return state_dict
@@ -295,11 +303,45 @@ def save_checkpoint(
295303
f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
296304
f" `{type(self).__name__}` does not use the `CheckpointIO`."
297305
)
298-
raise NotImplementedError("Checkpoint saving is not yet implemented.")
306+
# broadcast the path from rank 0 to ensure all the checkpoints are saved to a common path
307+
path = Path(self.broadcast(filepath))
308+
if path.is_dir() and not self._save_distributed_checkpoint and not _is_sharded_checkpoint(path):
309+
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
310+
311+
if self._save_distributed_checkpoint:
312+
if path.is_file():
313+
path.unlink()
314+
path.mkdir(parents=True, exist_ok=True)
315+
316+
converted_state = {"state_dict": checkpoint.pop("state_dict")}
317+
converted_state.update({
318+
f"optimizer_{idx}": optim_state
319+
for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
320+
})
321+
_distributed_checkpoint_save(converted_state, path)
322+
323+
if self.global_rank == 0:
324+
torch.save(checkpoint, path / _METADATA_FILENAME)
325+
else:
326+
if _is_sharded_checkpoint(path):
327+
shutil.rmtree(path)
328+
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
299329

300330
@override
301331
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
302-
raise NotImplementedError("Checkpoint loading is not yet implemented.")
332+
# broadcast the path from rank 0 to ensure all the states are loaded from a common path
333+
path = Path(self.broadcast(checkpoint_path))
334+
state = {
335+
"state_dict": self.model,
336+
**{f"optimizer_{idx}": optimizer for idx, optimizer in enumerate(self.optimizers)},
337+
}
338+
assert self.lightning_module is not None
339+
return _load_checkpoint(
340+
path=path,
341+
state=state,
342+
strict=self.lightning_module.strict_loading,
343+
optimizer_states_from_list=True,
344+
)
303345

304346
def _setup_distributed(self) -> None:
305347
super().setup_environment()

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,15 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0):
210210
trainer.strategy.setup_environment()
211211

212212

213-
def test_fsdp_custom_mixed_precision():
213+
def test_custom_mixed_precision():
214214
"""Test to ensure that passing a custom mixed precision config works."""
215215
config = MixedPrecision()
216216
strategy = FSDPStrategy(mixed_precision=config)
217217
assert strategy.mixed_precision_config == config
218218

219219

220220
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
221-
def test_fsdp_strategy_sync_batchnorm(tmp_path):
221+
def test_strategy_sync_batchnorm(tmp_path):
222222
"""Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
223223
model = TestFSDPModel()
224224
trainer = Trainer(
@@ -234,7 +234,7 @@ def test_fsdp_strategy_sync_batchnorm(tmp_path):
234234

235235

236236
@RunIf(min_cuda_gpus=1, skip_windows=True)
237-
def test_fsdp_modules_without_parameters(tmp_path):
237+
def test_modules_without_parameters(tmp_path):
238238
"""Test that TorchMetrics get moved to the device despite not having any parameters."""
239239

240240
class MetricsModel(BoringModel):
@@ -266,7 +266,7 @@ def training_step(self, batch, batch_idx):
266266
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
267267
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
268268
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
269-
def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmp_path):
269+
def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
270270
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
271271
model = TestFSDPModel()
272272
strategy = FSDPStrategy(state_dict_type=state_dict_type)
@@ -286,7 +286,7 @@ def custom_auto_wrap_policy(
286286

287287
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
288288
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
289-
def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
289+
def test_strategy_full_state_dict(tmp_path, wrap_min_params):
290290
"""Test to ensure that the full state dict is extracted when using FSDP strategy.
291291
292292
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all.
@@ -342,7 +342,7 @@ def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
342342
),
343343
],
344344
)
345-
def test_fsdp_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
345+
def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
346346
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
347347
ck = ModelCheckpoint(save_last=True)
348348

@@ -410,7 +410,7 @@ def configure_optimizers(self):
410410
trainer.fit(model)
411411

412412

413-
def test_fsdp_forbidden_precision_raises():
413+
def test_forbidden_precision_raises():
414414
with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
415415
FSDPStrategy(precision_plugin=HalfPrecision())
416416

@@ -419,7 +419,7 @@ def test_fsdp_forbidden_precision_raises():
419419
strategy.precision_plugin = HalfPrecision()
420420

421421

422-
def test_fsdp_activation_checkpointing():
422+
def test_activation_checkpointing():
423423
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
424424

425425
class Block1(nn.Linear):
@@ -469,7 +469,7 @@ def __init__(self):
469469
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
470470

471471

472-
def test_fsdp_strategy_cpu_offload():
472+
def test_strategy_cpu_offload():
473473
"""Test the different ways cpu offloading can be enabled."""
474474
# bool
475475
strategy = FSDPStrategy(cpu_offload=True)
@@ -481,7 +481,7 @@ def test_fsdp_strategy_cpu_offload():
481481
assert strategy.cpu_offload == config
482482

483483

484-
def test_fsdp_sharding_strategy():
484+
def test_sharding_strategy():
485485
"""Test the different ways the sharding strategy can be set."""
486486
from torch.distributed.fsdp import ShardingStrategy
487487

@@ -501,7 +501,7 @@ def test_fsdp_sharding_strategy():
501501

502502

503503
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
504-
def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
504+
def test_hybrid_sharding_strategy(sharding_strategy):
505505
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
506506
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
507507
FSDPStrategy(sharding_strategy=sharding_strategy)
@@ -523,7 +523,7 @@ def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
523523
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
524524

525525

526-
def test_fsdp_use_orig_params():
526+
def test_use_orig_params():
527527
"""Test that Lightning enables `use_orig_params` automatically."""
528528
strategy = FSDPStrategy()
529529
assert strategy.kwargs["use_orig_params"]
@@ -548,7 +548,7 @@ def test_set_timeout(init_process_group_mock):
548548

549549

550550
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
551-
def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
551+
def test_strategy_load_optimizer_states_multiple(_, tmp_path):
552552
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], state_dict_type="full")
553553
trainer = Trainer()
554554
trainer.state.fn = TrainerFn.FITTING
@@ -572,7 +572,7 @@ def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
572572

573573
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
574574
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
575-
def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
575+
def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
576576
"""Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy.
577577
578578
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
@@ -630,7 +630,7 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
630630

631631
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
632632
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
633-
def test_fsdp_strategy_load_optimizer_states(wrap_min_params, tmp_path):
633+
def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
634634
"""Test to ensure that the full state dict and optimizer states can be load when using FSDP strategy.
635635
636636
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
@@ -741,7 +741,7 @@ def test_save_checkpoint_storage_options(tmp_path):
741741
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context")
742742
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
743743
@mock.patch("lightning.pytorch.strategies.fsdp.shutil")
744-
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
744+
def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
745745
strategy = FSDPStrategy(state_dict_type="full")
746746

747747
# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -757,16 +757,12 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
757757
path.mkdir()
758758
(path / "meta.pt").touch()
759759
assert _is_sharded_checkpoint(path)
760-
model = Mock(spec=FullyShardedDataParallel)
761-
model.modules.return_value = [model]
762760
strategy.save_checkpoint(Mock(), filepath=path)
763761
shutil_mock.rmtree.assert_called_once_with(path)
764762

765763
# state_dict_type='full', path exists, path is a file: no error (overwrite)
766764
path = tmp_path / "file.pt"
767765
path.touch()
768-
model = Mock(spec=FullyShardedDataParallel)
769-
model.modules.return_value = [model]
770766
torch_save_mock.reset_mock()
771767
strategy.save_checkpoint(Mock(), filepath=path)
772768
torch_save_mock.assert_called_once()
@@ -783,30 +779,26 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
783779
path = tmp_path / "not-empty-2"
784780
path.mkdir()
785781
(path / "file").touch()
786-
model = Mock(spec=FullyShardedDataParallel)
787-
model.modules.return_value = [model]
788782
with save_mock:
789783
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
790784
assert (path / "file").exists()
791785

792786
# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
793787
path = tmp_path / "file-2.pt"
794788
path.touch()
795-
model = Mock(spec=FullyShardedDataParallel)
796-
model.modules.return_value = [model]
797789
with save_mock:
798790
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
799791
assert path.is_dir()
800792

801793

802794
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
803-
def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
795+
def test_save_checkpoint_unknown_state_dict_type(tmp_path):
804796
strategy = FSDPStrategy(state_dict_type="invalid")
805797
with pytest.raises(ValueError, match="Unknown state_dict_type"):
806798
strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path)
807799

808800

809-
def test_fsdp_load_unknown_checkpoint_type(tmp_path):
801+
def test_load_unknown_checkpoint_type(tmp_path):
810802
"""Test that the strategy validates the contents at the checkpoint path."""
811803
strategy = FSDPStrategy()
812804
strategy.model = Mock()
@@ -874,7 +866,7 @@ def test_save_load_sharded_state_dict(tmp_path):
874866
@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
875867
@mock.patch("lightning.pytorch.strategies.fsdp._lazy_load")
876868
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
877-
def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
869+
def test_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
878870
"""Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
879871
model = BoringModel()
880872
checkpoint = {"state_dict": model.state_dict()}

0 commit comments

Comments
 (0)