Commit 653dd6f

add weights_only args to strategies

1 parent a95f9ba
File tree

8 files changed (+14, -10 lines)
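All eight changes thread one flag through to torch.load. With weights_only=True, PyTorch swaps the full pickle machinery for a restricted unpickler that only rebuilds tensors, primitive containers, and allow-listed types; False keeps the legacy behavior, and the new None default leaves the decision to downstream code. A minimal standalone sketch of the two modes (file names are illustrative):

import torch

# Tensors and primitives round-trip fine under the restricted unpickler.
torch.save({"state_dict": {"w": torch.randn(2, 2)}, "step": 100}, "ok.pt")
print(torch.load("ok.pt", weights_only=True)["step"])  # 100

class Custom:  # arbitrary objects are where the two modes diverge
    pass

torch.save({"obj": Custom()}, "unsafe.pt")
try:
    torch.load("unsafe.pt", weights_only=True)  # restricted unpickler refuses
except Exception as exc:
    print(type(exc).__name__)  # an UnpicklingError-style failure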

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 2 additions & 1 deletion
@@ -458,6 +458,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.

@@ -483,7 +484,7 @@ def load_checkpoint(
             # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
             # a consolidated checkpoint
             path = self.broadcast(path)
-            return super().load_checkpoint(path=path, state=state, strict=strict)
+            return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)

         if not state:
             raise ValueError(
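The second hunk is the substantive one: this override previously called super().load_checkpoint(...) without the flag, so a caller's choice was silently dropped on the non-deepspeed code path. A hedged sketch of the forwarding pattern, with hypothetical class names rather than Lightning's real ones:

from typing import Any, Optional

class BaseStrategy:
    def load_checkpoint(self, path: str, weights_only: Optional[bool] = None) -> dict[str, Any]:
        # assumed resolution: None falls back to a framework-chosen default
        resolved = weights_only if weights_only is not None else False
        print(f"loading {path} with weights_only={resolved}")
        return {}

class DeepSpeedLike(BaseStrategy):
    def load_checkpoint(self, path: str, weights_only: Optional[bool] = None) -> dict[str, Any]:
        # forwarding the kwarg keeps the caller's setting intact
        return super().load_checkpoint(path, weights_only=weights_only)

DeepSpeedLike().load_checkpoint("ckpt", weights_only=True)  # True survives the hop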

src/lightning/fabric/strategies/fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -516,6 +516,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:
@@ -586,7 +587,7 @@ def load_checkpoint(
             optim.load_state_dict(flattened_osd)

         # Load metadata (anything not a module or optimizer)
-        metadata = torch.load(path / _METADATA_FILENAME)
+        metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
         for key in requested_metadata_keys:
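Only the small metadata file goes through torch.load here; the sharded tensors are loaded separately. If user code has stashed custom objects in that metadata, weights_only=True rejects them unless their classes are allow-listed first. A sketch of the escape hatch, assuming PyTorch >= 2.4 for torch.serialization.add_safe_globals (the class and file names are made up):

import torch

class RunInfo:  # stand-in for a user object saved alongside the weights
    def __init__(self, run_id: str):
        self.run_id = run_id

torch.save({"info": RunInfo("abc123"), "epoch": 7}, "meta.pt")

# Without the allow-list, this load raises under weights_only=True.
torch.serialization.add_safe_globals([RunInfo])
metadata = torch.load("meta.pt", weights_only=True)
print(metadata["info"].run_id, metadata["epoch"])  # abc123 7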

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 2 additions & 1 deletion
@@ -275,6 +275,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects."""
         if not state:
@@ -295,7 +296,7 @@ def load_checkpoint(
                 f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}."
             )

-        return _load_checkpoint(path=path, state=state, strict=strict)
+        return _load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only)

     def _setup_distributed(self) -> None:
         reset_seed()

src/lightning/fabric/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -310,7 +310,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
-        weights_only: bool = False,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
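Note the base signature moves from bool = False to Optional[bool] = None. A plain False default is indistinguishable from an explicit user False; None marks "not specified", so the framework can pick a context-dependent default without overriding a deliberate choice. The resolution rule below is an assumption for illustration, not Lightning's actual policy:

from typing import Optional

def resolve(weights_only: Optional[bool], trusted_checkpoint: bool) -> bool:
    if weights_only is not None:
        return weights_only           # an explicit choice always wins
    return not trusted_checkpoint     # assumed context-dependent fallback

assert resolve(None, trusted_checkpoint=False) is True
assert resolve(False, trusted_checkpoint=False) is False  # explicit False honored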

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -516,6 +516,7 @@ def load_checkpoint(
         path: _PATH,
         state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
         strict: bool = True,
+        weights_only: Optional[bool] = None,
     ) -> dict[str, Any]:
         """Given a folder, load the contents from a checkpoint and restore the state of the given objects.
@@ -608,7 +609,7 @@ def load_checkpoint(
             )
         if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
             raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
-        full_ckpt = torch.load(path)
+        full_ckpt = torch.load(path, weights_only=weights_only)
         model.load_state_dict(full_ckpt.pop("model"), strict=strict)
         return full_ckpt
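Here the whole consolidated checkpoint goes through torch.load, the "model" entry is popped and applied, and everything else is handed back to the caller. A small sketch of that contract (keys other than "model" are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save({"model": model.state_dict(), "step": 10, "lr": 0.1}, "full.pt")

full_ckpt = torch.load("full.pt", weights_only=True)
model.load_state_dict(full_ckpt.pop("model"), strict=True)
print(full_ckpt)  # leftover non-model entries: {'step': 10, 'lr': 0.1}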

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -659,12 +659,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
         )

     @override
-    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
             checkpoint_path = self.broadcast(checkpoint_path)
-            return super().load_checkpoint(checkpoint_path)
+            return super().load_checkpoint(checkpoint_path, weights_only)

         _validate_checkpoint_directory(checkpoint_path)

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -583,7 +583,7 @@ def save_checkpoint(
             raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")

     @override
-    def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
         # broadcast the path from rank 0 to ensure all the states are loaded from a common path
         path = Path(self.broadcast(checkpoint_path))

@@ -624,7 +624,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
                 optim.load_state_dict(flattened_osd)

             # Load metadata (anything not a module or optimizer)
-            metadata = torch.load(path / _METADATA_FILENAME)
+            metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only)
             return metadata

         if _is_full_checkpoint(path):

src/lightning/pytorch/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
         return self._lightning_module

-    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool) -> dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]:
         torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only)
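The base strategy just delegates to self.checkpoint_io, so a CheckpointIO plugin is where the flag ultimately lands. A hedged sketch of a plugin that honors it; whether your installed release's CheckpointIO.load_checkpoint already accepts weights_only depends on the version, so treat the signature as an assumption:

import os
from typing import Any, Optional

import torch
from lightning.pytorch.plugins.io import CheckpointIO

class SafeCheckpointIO(CheckpointIO):
    def load_checkpoint(self, path, map_location=None, weights_only: Optional[bool] = None) -> dict[str, Any]:
        # None is treated as False here; the framework's real resolution may differ
        return torch.load(path, map_location=map_location, weights_only=bool(weights_only))

    def save_checkpoint(self, checkpoint: dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def remove_checkpoint(self, path) -> None:
        os.remove(path)

Passing plugins=SafeCheckpointIO() to Trainer routes the load_checkpoint call above through this class.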
