
Commit ea6773e

weights_only for fsdp load
1 parent: b84a53d

2 files changed: +5 −3 lines


src/lightning/fabric/strategies/model_parallel.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -411,6 +411,7 @@ def _load_checkpoint(
     state: dict[str, Union[Module, Optimizer, Any]],
     strict: bool = True,
     optimizer_states_from_list: bool = False,
+    weights_only: bool = False,
 ) -> dict[str, Any]:
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
@@ -449,7 +450,7 @@ def _load_checkpoint(
             set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options)
 
         # Load metadata (anything not a module or optimizer)
-        metadata = torch.load(path / _METADATA_FILENAME)
+        metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
         for key in requested_metadata_keys:
@@ -461,7 +462,7 @@ def _load_checkpoint(
         return metadata
 
     if _is_full_checkpoint(path):
-        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
+        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=weights_only)
         _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
 
         state_dict_options = StateDictOptions(
```
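For context, the new parameter is forwarded to `torch.load`, where `weights_only=True` swaps the full pickle machinery for a restricted unpickler that only deserializes tensors, primitives, and other allow-listed types. A minimal sketch of the difference (the file path and payload here are hypothetical, not from this commit):

```python
import torch

# Save a checkpoint containing only tensors (hypothetical path).
torch.save({"weight": torch.randn(4, 4)}, "ckpt.pt")

# weights_only=True uses the restricted unpickler: a malicious pickle
# payload cannot execute arbitrary code during loading.
safe = torch.load("ckpt.pt", weights_only=True)

# weights_only=False (the default this commit preserves) runs the full
# pickle machinery and should only be used on trusted files.
trusted = torch.load("ckpt.pt", weights_only=False)
```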

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -329,7 +329,7 @@ def save_checkpoint(
         return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
 
     @override
-    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool = False) -> dict[str, Any]:
         # broadcast the path from rank 0 to ensure all the states are loaded from a common path
         path = Path(self.broadcast(checkpoint_path))
         state = {
@@ -342,6 +342,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
             state=state,
             strict=self.lightning_module.strict_loading,
             optimizer_states_from_list=True,
+            weights_only=weights_only,
         )
 
     def _setup_distributed(self) -> None:
```
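With this change a caller holding a configured strategy can opt into restricted loading. A hypothetical usage sketch (the `strategy` instance and checkpoint path are assumed to be already set up, not shown in this commit):

```python
# Hypothetical call site: request the restricted unpickler for the
# torch.load calls underneath the model-parallel strategy.
checkpoint = strategy.load_checkpoint("checkpoints/last.ckpt", weights_only=True)
```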
