
Commit 13f15b3

Support consolidating sharded checkpoints with the fabric CLI (#19560)

1 parent d9113b6
4 files changed: +70 -17 lines

docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst

Lines changed: 2 additions & 2 deletions
@@ -187,7 +187,7 @@ It is possible to convert a distributed checkpoint to a regular, single-file che

 .. code-block:: bash

-    python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
+    fabric consolidate path/to/my/checkpoint

 You will need to do this for example if you want to load the checkpoint into a script that doesn't use FSDP, or need to export the checkpoint to a different format for deployment, evaluation, etc.

@@ -202,7 +202,7 @@ You will need to do this for example if you want to load the checkpoint into a s

 .. code-block:: bash

-    python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt
+    fabric consolidate my-checkpoint.ckpt

 This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint which you can load normally in PyTorch:
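Note: the consolidated output is then a regular single-file checkpoint. A minimal loading sketch (the file name follows the default '.consolidated' suffix described above; the "model" key is an assumed example, not defined by this diff):

    import torch

    # Load the single-file checkpoint produced by `fabric consolidate my-checkpoint.ckpt`.
    checkpoint = torch.load("my-checkpoint.ckpt.consolidated")

    # The top-level key layout depends on what the training script saved;
    # "model" is an assumed example key holding a state dict.
    state_dict = checkpoint["model"]
    print(sorted(state_dict.keys()))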

src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

 -

src/lightning/fabric/cli.py

Lines changed: 34 additions & 0 deletions
@@ -19,14 +19,17 @@
 from argparse import Namespace
 from typing import Any, List, Optional

+import torch
 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import get_args

 from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
 from lightning.fabric.strategies import STRATEGY_REGISTRY
+from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.distributed import _suggested_max_num_threads
+from lightning.fabric.utilities.load import _load_distributed_checkpoint

 _log = logging.getLogger(__name__)
@@ -154,6 +157,37 @@ def _run(**kwargs: Any) -> None:
     script_args = list(kwargs.pop("script_args", []))
     main(args=Namespace(**kwargs), script_args=script_args)

+@_main.command(
+    "consolidate",
+    context_settings={
+        "ignore_unknown_options": True,
+    },
+)
+@click.argument(
+    "checkpoint_folder",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--output_file",
+    type=click.Path(exists=True),
+    default=None,
+    help=(
+        "Path to the file where the converted checkpoint should be saved. The file should not already exist."
+        " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
+        " and a '.consolidated' suffix."
+    ),
+)
+def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
+    """Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
+
+    Only supports FSDP sharded checkpoints at the moment.
+
+    """
+    args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
+    config = _process_cli_args(args)
+    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
+    torch.save(checkpoint, config.output_file)
+

 def _set_env_variables(args: Namespace) -> None:
     """Set the environment variables for the new processes.

tests/tests_fabric/test_cli.py

Lines changed: 33 additions & 14 deletions
@@ -20,7 +20,7 @@
 from unittest.mock import Mock

 import pytest
-from lightning.fabric.cli import _get_supported_strategies, _run
+from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run

 from tests_fabric.helpers.runif import RunIf

@@ -33,7 +33,7 @@ def fake_script(tmp_path):


 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_defaults(monkeypatch, fake_script):
+def test_run_env_vars_defaults(monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script])
@@ -49,7 +49,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
 @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
+def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", accelerator])
@@ -60,23 +60,23 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
 @pytest.mark.parametrize("strategy", _get_supported_strategies())
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
+def test_run_env_vars_strategy(_, strategy, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--strategy", strategy])
     assert e.value.code == 0
     assert os.environ["LT_STRATEGY"] == strategy


-def test_cli_get_supported_strategies():
+def test_run_get_supported_strategies():
     """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
     CLI."""
     assert len(_get_supported_strategies()) == 7
     assert "fsdp" in _get_supported_strategies()


 @pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork", "ddp_notebook", "deepspeed_stage_3_offload"])
-def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
+def test_run_env_vars_unsupported_strategy(strategy, fake_script):
     ioerr = StringIO()
     with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
         _run.main([fake_script, "--strategy", strategy])
@@ -87,7 +87,7 @@ def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
 @pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
+def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", "cuda", "--devices", devices])
@@ -98,7 +98,7 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
 @RunIf(mps=True)
 @pytest.mark.parametrize("accelerator", ["mps", "gpu"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
+def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--accelerator", accelerator])
@@ -108,7 +108,7 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):

 @pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
+def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--num-nodes", num_nodes])
@@ -118,7 +118,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):

 @pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
+def test_run_env_vars_precision(precision, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
     with pytest.raises(SystemExit) as e:
         _run.main([fake_script, "--precision", precision])
@@ -127,7 +127,7 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):


 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_cli_torchrun_defaults(monkeypatch, fake_script):
+def test_run_torchrun_defaults(monkeypatch, fake_script):
     torchrun_mock = Mock()
     monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
     with pytest.raises(SystemExit) as e:
@@ -155,7 +155,7 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
 )
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5)
-def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
+def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
     torchrun_mock = Mock()
     monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
     with pytest.raises(SystemExit) as e:
@@ -171,15 +171,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
     ])


-def test_cli_through_fabric_entry_point():
+def test_run_through_fabric_entry_point():
     result = subprocess.run("fabric run --help", capture_output=True, text=True, shell=True)

     message = "Usage: fabric run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
     assert message in result.stdout or message in result.stderr


 @pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
-def test_cli_through_lightning_entry_point():
+def test_run_through_lightning_entry_point():
     result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)

     deprecation_message = (
@@ -189,3 +189,22 @@ def test_cli_through_lightning_entry_point():
     message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
     assert deprecation_message in result.stdout
     assert message in result.stdout or message in result.stderr
+
+
+@mock.patch("lightning.fabric.cli._process_cli_args")
+@mock.patch("lightning.fabric.cli._load_distributed_checkpoint")
+@mock.patch("lightning.fabric.cli.torch.save")
+def test_consolidate(save_mock, _, __, tmp_path):
+    ioerr = StringIO()
+    with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
+        _consolidate.main(["not exist"])
+    assert e.value.code == 2
+    assert "Path 'not exist' does not exist" in ioerr.getvalue()
+
+    checkpoint_folder = tmp_path / "checkpoint"
+    checkpoint_folder.mkdir()
+    ioerr = StringIO()
+    with pytest.raises(SystemExit) as e, contextlib.redirect_stderr(ioerr):
+        _consolidate.main([str(checkpoint_folder)])
+    assert e.value.code == 0
+    save_mock.assert_called_once()
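
Note: the test drives the command through `_consolidate.main()` and catches `SystemExit` by hand. Click's `CliRunner` offers an equivalent pattern for the same assertions; a sketch, not part of this commit, relying on the error message click emits for `click.Path(exists=True)`:

    from click.testing import CliRunner

    from lightning.fabric.cli import _consolidate

    def test_consolidate_rejects_missing_path():
        # CliRunner records the exit code and output instead of raising SystemExit.
        result = CliRunner().invoke(_consolidate, ["not exist"])
        assert result.exit_code == 2  # click's usage-error exit code
        assert "Path 'not exist' does not exist" in result.output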
