16 | 16 | from functools import partial |
17 | 17 | from io import BytesIO |
18 | 18 | from pathlib import Path |
19 | | -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union |
| 19 | +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union |
20 | 20 |
21 | 21 | import torch |
22 | 22 | from lightning_utilities.core.apply_func import apply_to_collection |
27 | 27 |
28 | 28 | from lightning.fabric.utilities.imports import ( |
29 | 29 | _TORCH_GREATER_EQUAL_2_0, |
30 | | - _TORCH_GREATER_EQUAL_2_1, |
31 | | - _TORCH_GREATER_EQUAL_2_2, |
| 30 | + _TORCH_GREATER_EQUAL_2_3, |
32 | 31 | ) |
33 | 32 | from lightning.fabric.utilities.types import _PATH, _Stateful |
34 | 33 |
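Note: _TORCH_GREATER_EQUAL_2_3 is defined in lightning.fabric.utilities.imports and is not shown in this diff. A minimal sketch of how such a guard can be expressed with lightning_utilities, purely as an assumption — the actual definition in the imports module may differ:

import operator

from lightning_utilities.core.imports import compare_version

# Assumption: a boolean flag computed once at import time; the real flag in
# lightning.fabric.utilities.imports may be defined differently.
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")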
@@ -243,68 +242,24 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: |
243 | 242 | The current implementation assumes that the entire checkpoint fits in CPU memory. |
244 | 243 |
245 | 244 | """ |
246 | | - if not _TORCH_GREATER_EQUAL_2_1: |
247 | | - raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.") |
| 245 | + if not _TORCH_GREATER_EQUAL_2_3: |
| 246 | + raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.3.") |
248 | 247 |
249 | 248 | from torch.distributed.checkpoint import FileSystemReader |
250 | | - from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata |
| 249 | + from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner |
| 250 | + from torch.distributed.checkpoint.state_dict_loader import _load_state_dict |
251 | 251 |
252 | | - if _TORCH_GREATER_EQUAL_2_2: |
253 | | - from torch.distributed.checkpoint import load |
254 | | - else: |
255 | | - from torch.distributed.checkpoint import load_state_dict as load # deprecated |
256 | | - |
257 | | - reader = FileSystemReader(checkpoint_folder) |
258 | | - metadata = reader.read_metadata() |
259 | | - |
260 | | - # TODO: Add sequential save to avoid storing the entire checkpoint in memory |
261 | 252 | checkpoint: Dict[str, Any] = {} |
262 | | - for tensor_name, sd_metadata in metadata.state_dict_metadata.items(): |
263 | | - if isinstance(sd_metadata, BytesStorageMetadata): |
264 | | - checkpoint[tensor_name] = "<bytes_io>" |
265 | | - elif isinstance(sd_metadata, TensorStorageMetadata): |
266 | | - checkpoint[tensor_name] = torch.empty( |
267 | | - size=sd_metadata.size, |
268 | | - dtype=sd_metadata.properties.dtype, |
269 | | - device=torch.device("cpu"), |
270 | | - memory_format=sd_metadata.properties.memory_format, |
271 | | - layout=sd_metadata.properties.layout, |
272 | | - requires_grad=sd_metadata.properties.requires_grad, |
273 | | - pin_memory=sd_metadata.properties.pin_memory, |
274 | | - ) |
275 | | - |
276 | | - load(state_dict=checkpoint, storage_reader=reader, no_dist=True) |
277 | | - checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data) |
| 253 | + _load_state_dict( |
| 254 | + checkpoint, |
| 255 | + storage_reader=FileSystemReader(checkpoint_folder), |
| 256 | + planner=_EmptyStateDictLoadPlanner(), |
| 257 | + no_dist=True, |
| 258 | + ) |
278 | 259 |
279 | 260 | # This is the extra file saved by Fabric, with user data separate from weights and optimizer states |
280 | 261 | extra_file = checkpoint_folder / _METADATA_FILENAME |
281 | 262 | extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {} |
282 | 263 | checkpoint.update(extra) |
283 | 264 |
284 | 265 | return checkpoint |
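With the torch >= 2.3 path above, the entire checkpoint is materialized in CPU memory and can then be re-saved as a single file. A hedged usage sketch of the updated function — the paths below are illustrative and not part of this diff:

from pathlib import Path

import torch

# Hypothetical folder produced by a distributed (DCP) save plus Fabric's extra metadata file.
checkpoint_folder = Path("my_model/checkpoint.ckpt")

# Load every shard into one regular dict on CPU (assumes the checkpoint fits in memory),
# then write it out as a plain torch.save file.
full_checkpoint = _load_distributed_checkpoint(checkpoint_folder)
torch.save(full_checkpoint, "consolidated.ckpt")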
285 | | - |
286 | | - |
287 | | -def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]: |
288 | | - """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map. |
289 | | -
290 | | - Args: |
291 | | - checkpoint: The flat checkpoint dictionary. |
292 | | - key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing |
293 | | - the index path into the nested dictionary that this function should construct. |
294 | | -
295 | | - """ |
296 | | - assert checkpoint.keys() == key_map.keys() |
297 | | - converted: Dict[str, Any] = {} |
298 | | - for flat_key in checkpoint: |
299 | | - key_path = key_map[flat_key] |
300 | | - _set_nested_dict_value(converted, key_path, checkpoint[flat_key]) |
301 | | - return converted |
302 | | - |
303 | | - |
304 | | -def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None: |
305 | | - result = nested_dict |
306 | | - for key in key_path[:-1]: |
307 | | - if key not in result: |
308 | | - result[key] = {} |
309 | | - result = result[key] |
310 | | - result[key_path[-1]] = value |
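For context, a small self-contained illustration of what the removed helpers did: DCP metadata keeps flat keys such as "model.layer.weight" alongside a planner_data map of index paths, and _unflatten_dict/_set_nested_dict_value rebuilt the nested dictionary by hand. The values below are made up; with torch >= 2.3, _EmptyStateDictLoadPlanner performs this reconstruction internally, which is why the helpers can be deleted.

# Made-up flat checkpoint and key map, mirroring metadata.planner_data.
flat = {"model.layer.weight": [0.0, 0.0], "optimizer.state.step": 10}
key_map = {
    "model.layer.weight": ("model", "layer", "weight"),
    "optimizer.state.step": ("optimizer", "state", "step"),
}

nested: dict = {}
for flat_key, value in flat.items():
    *parents, leaf = key_map[flat_key]
    node = nested
    for key in parents:
        node = node.setdefault(key, {})  # create intermediate dicts on demand
    node[leaf] = value

# nested == {"model": {"layer": {"weight": [0.0, 0.0]}},
#            "optimizer": {"state": {"step": 10}}}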