Commit 8bfbe0c

Fix strict loading from distributed checkpoints vs PyTorch nightly (#19946)
* strict loading
* docstring
1 parent 19f0fb9 commit 8bfbe0c

File tree: 1 file changed (+3, -8 lines)

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 3 additions & 8 deletions
@@ -275,12 +275,7 @@ def load_checkpoint(
         state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
     ) -> Dict[str, Any]:
-        """Load the contents from a checkpoint and restore the state of the given objects.
-
-        Currently does not support loading the optimizer state if the model is distributed but the checkpoint is a full,
-        non-distributed checkpoint.
-
-        """
+        """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:
             raise ValueError(
                 f"Got {type(self).__name__}.load_checkpoint(..., state={state!r}) but a state with at least "
@@ -559,14 +554,14 @@ def _load_raw_module_state(
         state_dict_options = StateDictOptions(
             broadcast_from_rank0=True,  # type: ignore[call-arg]
             full_state_dict=True,
-            strict=strict,  # gets ignored at the moment
+            # must be set False to allow loading each param separately below
+            strict=False,
         )
 
         for submodule_name, submodule in module.named_modules():
             for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
                 full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
                 if full_param_name not in state_dict:
-                    # Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
                     if not strict:
                         continue
                     raise KeyError(
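The functional fix is in the second hunk: PyTorch's `StateDictOptions.strict` flag cannot be honored when each parameter is loaded individually (every call would see an "incomplete" state dict), so the option is pinned to `False` and strictness is enforced by hand in the surrounding loop. Below is a minimal, self-contained sketch of that pattern, not the Lightning source itself: `load_full_state_dict_per_param` is a hypothetical name, it loads only parameters for brevity (the real helper also handles buffers via `_named_parameters_and_buffers_to_load`), and it assumes a recent PyTorch build where `broadcast_from_rank0` exists, running with an initialized process group.

# Hedged sketch of the pattern this commit adopts; assumes torch >= 2.4
# nightly (broadcast_from_rank0) and an initialized process group.
from typing import Any, Dict

from torch import nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_state_dict_per_param(  # hypothetical helper, not Lightning API
    module: nn.Module, state_dict: Dict[str, Any], strict: bool = True
) -> None:
    options = StateDictOptions(
        broadcast_from_rank0=True,
        full_state_dict=True,
        # Must be False: each set_model_state_dict call below receives a
        # single-parameter dict, which PyTorch's own strict check would
        # otherwise reject as incomplete.
        strict=False,
    )
    for submodule_name, submodule in module.named_modules():
        # Parameters only, for brevity; the real code also loads buffers.
        for param_name, _ in submodule.named_parameters(recurse=False):
            full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
            if full_param_name not in state_dict:
                if not strict:
                    continue  # strictness enforced manually, not by PyTorch
                raise KeyError(f"Checkpoint is missing parameter {full_param_name!r}")
            # Load exactly one parameter into this (possibly sharded) submodule.
            set_model_state_dict(submodule, {param_name: state_dict[full_param_name]}, options=options)

Checking keys manually while passing `strict=False` to PyTorch keeps the user-facing `strict` semantics intact, which is what `load_checkpoint(..., strict=...)` in the first hunk threads down to this helper (and, presumably, what Fabric's checkpoint loading exposes to users).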
