
# strict = False does not work when the checkpoint is distributed #20274

@NathanGodey

### Bug description

When loading a sharded checkpoint with:

```python
fabric.load(ckpt_path, state, strict=False)
```

the `_distributed_checkpoint_load` function called by the `FSDPStrategy` raises an error whenever the checkpoint is missing a key that the model in `state` has, which should not happen when `strict=False`.
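
For context, a hypothetical minimal setup that hits this code path could look like the following (model, sizes, and paths are made up for illustration and are not taken from the report):

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=8, strategy="fsdp")
fabric.launch()

# Save a sharded (distributed) checkpoint for a small model.
model = fabric.setup(torch.nn.Linear(128, 128))
fabric.save("ckpt", {"model": model})

# Reload into a model that has parameters the checkpoint does not contain
# (e.g. a newly added head). With strict=False these keys should simply be
# skipped, but the load raises "Missing key in checkpoint state_dict: ...".
new_model = fabric.setup(
    torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.Linear(128, 10))
)
fabric.load("ckpt", {"model": new_model}, strict=False)
```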

A fix could be to take advantage of the `DefaultLoadPlanner` accepted by `torch.distributed.checkpoint.load`, setting its `allow_partial_load` argument to the opposite of `strict`.
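
A minimal sketch of that idea, assuming Lightning's FSDP strategy threads its `strict` flag down into `_distributed_checkpoint_load` (the current helper in `lightning/fabric/strategies/fsdp.py` does not accept such an argument):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import DefaultLoadPlanner


def _distributed_checkpoint_load(module_state: dict, path, strict: bool = True) -> None:
    # With allow_partial_load=True the planner skips entries of `module_state`
    # that have no matching key in the checkpoint, instead of raising
    # "Missing key in checkpoint state_dict: ...".
    dcp.load(
        module_state,
        checkpoint_id=path,
        planner=DefaultLoadPlanner(allow_partial_load=not strict),
    )
```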

### What version are you seeing the problem on?

v2.4

### How to reproduce the bug

No response

### Error messages and logs

[rank7]: Traceback (most recent call last):
[rank7]:   File "my_codebase/train_fabric.py", line 226, in <module>
[rank7]:     main(**vars(args))
[rank7]:   File "my_codebase/train_fabric.py", line 148, in main
[rank7]:     fabric.load(ckpt_path, state, strict = strict_mode)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 773, in load
[rank7]:     remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 570, in load_checkpoint
[rank7]:     _distributed_checkpoint_load(module_state, path)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 886, in _distributed_checkpoint_load
[rank7]:     load(module_state, checkpoint_id=path)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
[rank7]:     _load_state_dict(
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
[rank7]:     central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 1)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 2)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.my_key.

### Environment

- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0+rocm6.0
- Python version: 3.11


### More info

_No response_

cc @lantiga @justusschock


Labels: bug (Something isn't working), checkpointing (Related to checkpointing), fabric (lightning.fabric.Fabric), ver: 2.4.x
