Commit b3c869f

Revise checkpoint consolidation with PyTorch 2.3 (#19561)
1 parent 527d071 commit b3c869f

6 files changed: 23 additions, 91 deletions

src/lightning/fabric/utilities/consolidate_checkpoint.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 
 import torch
 
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint
 
 _log = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def _parse_cli_args() -> Namespace:
 
 
 def _process_cli_args(args: Namespace) -> Namespace:
-    if not _TORCH_GREATER_EQUAL_2_1:
-        _log.error("Processing distributed checkpoints requires PyTorch >= 2.1.")
+    if not _TORCH_GREATER_EQUAL_2_3:
+        _log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
         exit(1)
 
     checkpoint_folder = Path(args.checkpoint_folder)
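This gate guards the consolidation CLI, which is invoked as `python -m lightning.fabric.utilities.consolidate_checkpoint <checkpoint_folder>`. As a rough sketch, the same flow can be driven from Python with the helper imported above; the folder path and output name below are hypothetical:

# Minimal sketch, assuming PyTorch >= 2.3 and a DCP checkpoint folder on disk.
from pathlib import Path

import torch

from lightning.fabric.utilities.load import _load_distributed_checkpoint

checkpoint_folder = Path("checkpoint_sharded")  # hypothetical sharded checkpoint dir
full_checkpoint = _load_distributed_checkpoint(checkpoint_folder)  # full state dict on CPU
torch.save(full_checkpoint, checkpoint_folder.with_suffix(".consolidated"))  # example output name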

src/lightning/fabric/utilities/load.py

Lines changed: 12 additions & 57 deletions
@@ -16,7 +16,7 @@
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -27,8 +27,7 @@
 
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
-    _TORCH_GREATER_EQUAL_2_1,
-    _TORCH_GREATER_EQUAL_2_2,
+    _TORCH_GREATER_EQUAL_2_3,
 )
 from lightning.fabric.utilities.types import _PATH, _Stateful
 
@@ -243,68 +242,24 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
     The current implementation assumes that the entire checkpoint fits in CPU memory.
 
     """
-    if not _TORCH_GREATER_EQUAL_2_1:
-        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+    if not _TORCH_GREATER_EQUAL_2_3:
+        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")
 
     from torch.distributed.checkpoint import FileSystemReader
-    from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 
-    if _TORCH_GREATER_EQUAL_2_2:
-        from torch.distributed.checkpoint import load
-    else:
-        from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
-    reader = FileSystemReader(checkpoint_folder)
-    metadata = reader.read_metadata()
-
-    # TODO: Add sequential save to avoid storing the entire checkpoint in memory
     checkpoint: Dict[str, Any] = {}
-    for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
-        if isinstance(sd_metadata, BytesStorageMetadata):
-            checkpoint[tensor_name] = "<bytes_io>"
-        elif isinstance(sd_metadata, TensorStorageMetadata):
-            checkpoint[tensor_name] = torch.empty(
-                size=sd_metadata.size,
-                dtype=sd_metadata.properties.dtype,
-                device=torch.device("cpu"),
-                memory_format=sd_metadata.properties.memory_format,
-                layout=sd_metadata.properties.layout,
-                requires_grad=sd_metadata.properties.requires_grad,
-                pin_memory=sd_metadata.properties.pin_memory,
-            )
-
-    load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
-    checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+    _load_state_dict(
+        checkpoint,
+        storage_reader=FileSystemReader(checkpoint_folder),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
 
     # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
     extra_file = checkpoint_folder / _METADATA_FILENAME
     extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
     checkpoint.update(extra)
 
     return checkpoint
-
-
-def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
-    """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
-
-    Args:
-        checkpoint: The flat checkpoint dictionary.
-        key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
-            the index path into the nested dictonary that this function should construct.
-
-    """
-    assert checkpoint.keys() == key_map.keys()
-    converted: Dict[str, Any] = {}
-    for flat_key in checkpoint:
-        key_path = key_map[flat_key]
-        _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
-    return converted
-
-
-def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None:
-    result = nested_dict
-    for key in key_path[:-1]:
-        if key not in result:
-            result[key] = {}
-        result = result[key]
-    result[key_path[-1]] = value
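The rewrite replaces the manual consolidation loop: instead of pre-allocating empty CPU tensors from the reader's metadata and unflattening keys afterwards, PyTorch 2.3's `_EmptyStateDictLoadPlanner` materializes the complete, already-nested state dict into an empty dict during the load. PyTorch also exposes a public wrapper around the same internals; a minimal sketch of the equivalent consolidation with it, using example paths:

# Sketch, assuming PyTorch >= 2.3; paths are examples only.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("checkpoint_sharded", "checkpoint.pt")  # DCP folder -> single file
state = torch.load("checkpoint.pt", map_location="cpu")
print(list(state))  # nested top-level keys, e.g. "model", "optimizer"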

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 3 additions & 3 deletions
@@ -621,8 +621,7 @@ def test_clip_gradients(clip_type, precision):
     optimizer.zero_grad()
 
 
-# TODO: Support checkpoint consolidation with PyTorch >= 2.2
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
 
@@ -639,7 +638,8 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
     state = {"model": model, "optimizer": optimizer, "steps": 1}
 
     # run one iteration to init the state of the optimizer
-    model(torch.rand(1, 32, device=fabric.device)).sum().backward()
+    loss = model(torch.rand(1, 32, device=fabric.device)).sum()
+    fabric.backward(loss)
     optimizer.step()
 
     checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
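The second hunk routes the backward pass through `fabric.backward(loss)` instead of calling `.backward()` on the tensor directly, which lets the strategy (here FSDP) and the precision plugin hook into the backward pass. A minimal sketch of that idiom on a single CPU device, outside the test harness:

# Sketch of the Fabric training-step idiom the test switches to.
import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model(torch.rand(1, 32, device=fabric.device)).sum()
fabric.backward(loss)  # not loss.backward(); the strategy may need to intercept
optimizer.step()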

tests/tests_fabric/utilities/test_consolidate_checkpoint.py

Lines changed: 4 additions & 4 deletions
@@ -39,15 +39,15 @@ def test_parse_cli_args(args, expected):
 
 
 def test_process_cli_args(tmp_path, caplog, monkeypatch):
-    # PyTorch version < 2.1
-    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", False)
+    # PyTorch version < 2.3
+    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False)
     with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises(
         SystemExit
     ):
         _process_cli_args(Namespace())
-    assert "requires PyTorch >= 2.1." in caplog.text
+    assert "requires PyTorch >= 2.3." in caplog.text
     caplog.clear()
-    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", True)
+    monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", True)
 
     # Checkpoint does not exist
     checkpoint_folder = Path("does/not/exist")

tests/tests_fabric/utilities/test_load.py

Lines changed: 0 additions & 22 deletions
@@ -19,7 +19,6 @@
     _materialize_tensors,
     _move_state_into,
     _NotYetLoadedTensor,
-    _unflatten_dict,
 )
 
 from tests_fabric.helpers.runif import RunIf
@@ -145,24 +144,3 @@ def load_state_dict(self, state_dict):
     assert source == {}
     assert destination["cocofruit"] == 2
     assert destination["banana"].count == 100
-
-
-def test_unflatten_dict():
-    assert _unflatten_dict({}, {}) == {}
-
-    tensor0 = torch.rand(2, 2)
-    tensor1 = torch.tensor(3.0)
-    data = {
-        "model.layer.weight": tensor0,
-        "optimizer.state.layer.weight.exp_avg": {"test": tensor1},
-        "optimizer.param_groups": "param_groups",
-    }
-    key_map = {
-        "model.layer.weight": ("model", "layer.weight"),
-        "optimizer.state.layer.weight.exp_avg": ("optimizer", "state", "layer.weight", "exp_avg"),
-        "optimizer.param_groups": ("optimizer", "param_groups"),
-    }
-    assert _unflatten_dict(data, key_map) == {
-        "model": {"layer.weight": tensor0},
-        "optimizer": {"state": {"layer.weight": {"exp_avg": {"test": tensor1}}}, "param_groups": "param_groups"},
-    }
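`test_unflatten_dict` is deleted along with the helper it covered, since `_EmptyStateDictLoadPlanner` now produces the nested structure during the load itself. For reference, a simplified sketch of the unflattening that helper performed; it splits keys on dots, whereas the removed code consulted the planner's key map precisely because saved keys such as "layer.weight" can themselves contain dots:

# Simplified illustration only; real DCP metadata supplies an exact key map.
from typing import Any, Dict

flat: Dict[str, Any] = {"model.weight": "w", "optimizer.param_groups": "pg"}
nested: Dict[str, Any] = {}
for flat_key, value in flat.items():
    *parents, leaf = flat_key.split(".")
    node = nested
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

assert nested == {"model": {"weight": "w"}, "optimizer": {"param_groups": "pg"}}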

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 1 addition & 2 deletions
@@ -1013,8 +1013,7 @@ def _run_setup_assertions(empty_init, expected_device):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
 
 
-# TODO: Support checkpoint consolidation with PyTorch >= 2.2
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
