@@ -16,7 +16,7 @@
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -27,8 +27,7 @@
 
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
-    _TORCH_GREATER_EQUAL_2_1,
-    _TORCH_GREATER_EQUAL_2_2,
+    _TORCH_GREATER_EQUAL_2_3,
 )
 from lightning.fabric.utilities.types import _PATH, _Stateful
 
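Note on the version gate above: the separate _TORCH_GREATER_EQUAL_2_1 and _TORCH_GREATER_EQUAL_2_2 flags are dropped in favor of a single _TORCH_GREATER_EQUAL_2_3 flag. As a rough sketch only (the actual definition lives in lightning.fabric.utilities.imports and may differ in detail), such a flag can be built from lightning_utilities' compare_version helper:

import operator

from lightning_utilities.core.imports import compare_version

# Sketch: True when the installed torch is at least 2.3.0.
# use_base_version=True is an assumption here so that 2.3 pre-release builds also pass the gate.
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)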
@@ -243,68 +242,24 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
     The current implementation assumes that the entire checkpoint fits in CPU memory.
 
     """
-    if not _TORCH_GREATER_EQUAL_2_1:
-        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+    if not _TORCH_GREATER_EQUAL_2_3:
+        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.")
 
     from torch.distributed.checkpoint import FileSystemReader
-    from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 
-    if _TORCH_GREATER_EQUAL_2_2:
-        from torch.distributed.checkpoint import load
-    else:
-        from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
-    reader = FileSystemReader(checkpoint_folder)
-    metadata = reader.read_metadata()
-
-    # TODO: Add sequential save to avoid storing the entire checkpoint in memory
     checkpoint: Dict[str, Any] = {}
-    for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
-        if isinstance(sd_metadata, BytesStorageMetadata):
-            checkpoint[tensor_name] = "<bytes_io>"
-        elif isinstance(sd_metadata, TensorStorageMetadata):
-            checkpoint[tensor_name] = torch.empty(
-                size=sd_metadata.size,
-                dtype=sd_metadata.properties.dtype,
-                device=torch.device("cpu"),
-                memory_format=sd_metadata.properties.memory_format,
-                layout=sd_metadata.properties.layout,
-                requires_grad=sd_metadata.properties.requires_grad,
-                pin_memory=sd_metadata.properties.pin_memory,
-            )
-
-    load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
-    checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+    _load_state_dict(
+        checkpoint,
+        storage_reader=FileSystemReader(checkpoint_folder),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
 
     # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
     extra_file = checkpoint_folder / _METADATA_FILENAME
     extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
     checkpoint.update(extra)
 
     return checkpoint
-
-
-def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
-    """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
-
-    Args:
-        checkpoint: The flat checkpoint dictionary.
-        key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
-            the index path into the nested dictonary that this function should construct.
-
-    """
-    assert checkpoint.keys() == key_map.keys()
-    converted: Dict[str, Any] = {}
-    for flat_key in checkpoint:
-        key_path = key_map[flat_key]
-        _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
-    return converted
-
-
-def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None:
-    result = nested_dict
-    for key in key_path[:-1]:
-        if key not in result:
-            result[key] = {}
-        result = result[key]
-    result[key_path[-1]] = value
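For reference, the new loading path can be exercised end to end outside of Fabric. The following is a minimal sketch, assuming PyTorch >= 2.3, a DCP checkpoint folder at the placeholder path "path/to/checkpoint", and a placeholder output file "consolidated.ckpt"; it uses the same FileSystemReader plus the private _EmptyStateDictLoadPlanner and _load_state_dict helpers as the diff above:

from typing import Any, Dict

import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

checkpoint: Dict[str, Any] = {}
_load_state_dict(
    checkpoint,  # filled in place with the fully materialized, unsharded state dict
    storage_reader=FileSystemReader("path/to/checkpoint"),  # placeholder checkpoint folder
    planner=_EmptyStateDictLoadPlanner(),  # rebuilds keys and tensors from the checkpoint metadata alone
    no_dist=True,  # read locally; no torch.distributed process group is required
)
torch.save(checkpoint, "consolidated.ckpt")  # optionally persist as a single-file checkpoint

PyTorch 2.3 also exposes torch.distributed.checkpoint.format_utils.dcp_to_torch_save, which wraps essentially this pattern behind a public function and may be preferable where depending on private helpers is a concern.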
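For context on what the deleted _unflatten_dict and _set_nested_dict_value helpers did, and why they are no longer needed: DCP stores a flat state dict with dotted keys plus a key map in metadata.planner_data, and the _EmptyStateDictLoadPlanner used above takes over that reconstruction while loading. A standalone illustration with made-up data (the key_map below only mimics the shape of planner_data):

from typing import Any, Dict, Tuple

flat: Dict[str, Any] = {"model.layer.weight": 1, "optimizer.state.step": 2}
key_map: Dict[str, Tuple[str, ...]] = {
    "model.layer.weight": ("model", "layer.weight"),
    "optimizer.state.step": ("optimizer", "state", "step"),
}

nested: Dict[str, Any] = {}
for flat_key, value in flat.items():
    node = nested
    for key in key_map[flat_key][:-1]:
        node = node.setdefault(key, {})  # create intermediate dicts on the way down
    node[key_map[flat_key][-1]] = value

assert nested == {"model": {"layer.weight": 1}, "optimizer": {"state": {"step": 2}}}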