Bug description
When loading a sharded checkpoint with fabric.load(ckpt_path, state, strict=False), the _distributed_checkpoint_load function called by the FSDPStrategy raises an error if the checkpoint is missing a key that is present in the model held in state. This should not happen when strict=False.
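For context, here is a minimal sketch of the call pattern that triggers the error (the model, device count, and checkpoint path are placeholders, not taken from the actual training script):

```python
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", devices=8)
fabric.launch()

# Stand-in for the real model; with FSDPStrategy the checkpoint on disk is sharded.
model = fabric.setup(nn.Linear(16, 16))
state = {"model": model}

# The checkpoint is missing some of the model's keys. With strict=False the
# overlapping keys should load and the rest be skipped, but
# _distributed_checkpoint_load raises a CheckpointException instead.
fabric.load("path/to/sharded_checkpoint", state, strict=False)
```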
A possible fix is to take advantage of the DefaultLoadPlanner in torch.distributed.checkpoint.load by setting its allow_partial_load argument to the opposite of strict.
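A rough sketch of what that change could look like in _distributed_checkpoint_load (the strict parameter, and how it would be threaded through from load_checkpoint, are assumptions on my part):

```python
from pathlib import Path
from typing import Any, Dict

from torch.distributed.checkpoint import DefaultLoadPlanner, load


def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path, strict: bool = True) -> None:
    # allow_partial_load=True tells the default planner to skip keys that exist
    # in the module state but are missing from the checkpoint, which matches the
    # expected strict=False behaviour.
    load(
        module_state,
        checkpoint_id=path,
        planner=DefaultLoadPlanner(allow_partial_load=not strict),
    )
```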
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
[rank7]: Traceback (most recent call last):
[rank7]: File "my_codebase/train_fabric.py", line 226, in <module>
[rank7]: main(**vars(args))
[rank7]: File "my_codebase/train_fabric.py", line 148, in main
[rank7]: fabric.load(ckpt_path, state, strict = strict_mode)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 773, in load
[rank7]: remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 570, in load_checkpoint
[rank7]: _distributed_checkpoint_load(module_state, path)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 886, in _distributed_checkpoint_load
[rank7]: load(module_state, checkpoint_id=path)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]: result = func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
[rank7]: return func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
[rank7]: _load_state_dict(
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
[rank7]: central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank7]: raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]: local_data = map_fun()
[rank7]: ^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]: result = func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]: local_plan = planner.create_local_plan()
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]: return create_default_local_load_plan(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]: raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 1)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]: local_data = map_fun()
[rank7]: ^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]: result = func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]: local_plan = planner.create_local_plan()
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]: return create_default_local_load_plan(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]: raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 2)
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]: local_data = map_fun()
[rank7]: ^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]: result = func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]: local_plan = planner.create_local_plan()
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]: return create_default_local_load_plan(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]: raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.my_key.
Environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0+rocm6.0
- Python version: 3.11
More info
No response
cc @lantiga @justusschock