diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index a0c2f49bfee60..34b89fc208c7f 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -29,23 +29,23 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v5 - - uses: actions/setup-python@v6 - with: - python-version: "3.11" - - name: Mypy cache - uses: actions/cache@v4 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v6 with: - path: .mypy_cache - key: mypy-${{ hashFiles('requirements/typing.txt') }} + python-version: "3.11" + # TODO: Avoid activating environment like this + # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment + activate-environment: true + enable-cache: true - name: Install dependencies env: FREEZE_REQUIREMENTS: 1 timeout-minutes: 20 run: | - pip install -e '.[pytorch-all,fabric-all]' -r requirements/typing.txt - pip list + uv pip install '.[pytorch-all,fabric-all]' -r requirements/typing.txt + uv pip list - name: Check typing run: mypy diff --git a/pyproject.toml b/pyproject.toml index bec03d8164ad4..c2d754b2250fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated + "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", "ignore:The pynvml package is deprecated:FutureWarning", # Ignore pynvml deprecation warning, since it is not installed by PL directly ] xfail_strict = true diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index e949a9fac5a9b..0aa0ee3a5c402 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -10,8 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060)) - - - @@ -22,8 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- let `_get_default_process_group_backend_for_device` support more hardware platforms ( - [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093)) +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) +- let `_get_default_process_group_backend_for_device` support more hardware platforms ([#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093)) ### Fixed diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 3a33dac3335d1..db7578d9ca8c0 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None + ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. 
Args: path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 90a5f62ba7413..c52ad6913e1e2 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -59,7 +59,10 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage + self, + path: _PATH, + map_location: Optional[Callable] = lambda storage, loc: storage, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. @@ -67,6 +70,11 @@ def load_checkpoint( path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. @@ -80,7 +88,7 @@ def load_checkpoint( if not fs.exists(path): raise FileNotFoundError(f"Checkpoint file not found: {path}") - return pl_load(path, map_location=map_location) + return pl_load(path, map_location=map_location, weights_only=weights_only) @override def remove_checkpoint(self, path: _PATH) -> None: diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 0b2b373acb7bc..fe766e0432b8e 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -310,6 +310,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: bool = False, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. 
@@ -330,7 +331,7 @@ """ torch.cuda.empty_cache() - checkpoint = self.checkpoint_io.load_checkpoint(path) + checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only=weights_only) if not state: return checkpoint diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 637dfcd9b1671..54b18fb6ce3b0 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -17,7 +17,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Union +from typing import IO, Any, Optional, Union import fsspec import fsspec.utils @@ -34,13 +34,18 @@ def _load( path_or_url: Union[IO, _PATH], map_location: _MAP_LOCATION_TYPE = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> Any: """Loads a checkpoint. Args: path_or_url: Path or URL of the checkpoint. map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive + types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. """ if not isinstance(path_or_url, (str, Path)): return torch.load( path_or_url, map_location=map_location, # type: ignore[arg-type] weights_only=weights_only, ) if str(path_or_url).startswith("http"): + if weights_only is None: + weights_only = False + log.debug( + f"Defaulting to `weights_only=False` for remote checkpoint: {path_or_url}." + f" If loading a checkpoint from an untrusted source, we recommend using `weights_only=True`." + ) + return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] @@ -70,7 +82,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fd6f0a34ccec4..9d4e7abf8b6f5 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -27,7 +27,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. 
Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) +- Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) + + +- ### Removed diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index ff84c2fd8b199..07ec02ef87bd8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -177,6 +177,7 @@ def load_from_checkpoint( checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the @@ -206,6 +207,11 @@ def load_from_checkpoint( If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace` and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat ``hparams`` as :class:`~dict`. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved hyperparameter values. @@ -242,6 +248,7 @@ def load_from_checkpoint( map_location=map_location, hparams_file=hparams_file, strict=None, + weights_only=weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 85f631ee40f75..37b07f025f8e9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1690,6 +1690,7 @@ def load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments @@ -1723,6 +1724,11 @@ def load_from_checkpoint( strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is set, in which case it defaults to the value of ``LightningModule.strict_loading``. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading a checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. 
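For reference, a minimal usage sketch of the new ``weights_only`` argument on ``load_from_checkpoint`` (the checkpoint path is hypothetical; ``BoringModel`` is the demo model used in the tests further below):

```python
from lightning.pytorch.demos.boring_classes import BoringModel

# Checkpoint containing only tensors/primitives (e.g. from an untrusted source):
# weights_only=True refuses to unpickle arbitrary objects.
model = BoringModel.load_from_checkpoint("example.ckpt", weights_only=True)

# Trusted checkpoint whose hparams pickle richer objects (an nn.Module, omegaconf containers):
# weights_only=False is still required, as exercised in the updated hparams tests below.
model = BoringModel.load_from_checkpoint("example.ckpt", weights_only=False)
```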
@@ -1778,6 +1784,7 @@ def load_from_checkpoint( map_location, hparams_file, strict, + weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 21fd3912f7849..391e9dd5d0f25 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -56,11 +56,13 @@ def _load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location + with pl_legacy_patch(): - checkpoint = pl_load(checkpoint_path, map_location=map_location) + checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only) # convert legacy checkpoints to the new format checkpoint = _pl_migrate_checkpoint( diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 16b16a4927513..4e5dc33f62672 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -363,9 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: bool) -> dict[str, Any]: torch.cuda.empty_cache() - return self.checkpoint_io.load_checkpoint(checkpoint_path) + return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: assert self.lightning_module is not None diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 7f97a2f54bf19..863fd265a6cfd 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]: return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt" return None - def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: + def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. 
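Because ``Strategy.load_checkpoint`` on the PyTorch side now takes ``weights_only`` (see the signature change above and the updated ``test_cpu.py`` below), custom strategy overrides have to accept and forward it. A minimal sketch, assuming a hypothetical subclass of ``SingleDeviceStrategy``:

```python
from pathlib import Path
from typing import Any, Union

from lightning.pytorch.strategies import SingleDeviceStrategy


class VerboseSingleDeviceStrategy(SingleDeviceStrategy):
    def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]:
        # Forward weights_only so the underlying CheckpointIO decides how torch.load is called.
        print(f"Loading {checkpoint_path} (weights_only={weights_only})")
        return super().load_checkpoint(checkpoint_path, weights_only)
```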
@@ -80,7 +80,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): - loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only) self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( @@ -403,9 +403,11 @@ def restore_lr_schedulers(self) -> None: for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) - def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: + def _restore_modules_and_callbacks( + self, checkpoint_path: Optional[_PATH] = None, weights_only: bool = False + ) -> None: # restore modules after setup - self.resume_start(checkpoint_path) + self.resume_start(checkpoint_path, weights_only) self.restore_model() self.restore_datamodule() self.restore_callbacks() @@ -414,7 +416,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: - weights_only: saving model weights only + weights_only: If True, only saves model and loops state_dict objects. If False, + additionally saves callbacks, optimizers, schedulers, and precision plugin states. + Return: structured dictionary: { 'epoch': training epoch diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 5768c507e2e3f..739ecc369609f 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -526,6 +526,7 @@ def fit( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> None: r"""Runs the full optimization routine. 
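The flag is also surfaced on the ``Trainer`` entry points and threaded through ``_run`` to the checkpoint connector, as the hunks above and below show. A usage sketch (checkpoint paths are hypothetical):

```python
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = pl.Trainer(max_epochs=2)
model = BoringModel()

# Resume training from a checkpoint while refusing to unpickle arbitrary objects.
trainer.fit(model, ckpt_path="last.ckpt", weights_only=True)

# The evaluation entry points (validate/test/predict) accept the same keyword.
trainer.test(model, ckpt_path="best.ckpt", weights_only=True)
```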
@@ -573,7 +574,14 @@ def fit( self.training = True self.should_stop = False call._call_and_handle_interrupt( - self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + self, + self._fit_impl, + model, + train_dataloaders, + val_dataloaders, + datamodule, + ckpt_path, + weights_only, ) def _fit_impl( @@ -583,6 +591,7 @@ def _fit_impl( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -610,7 +619,7 @@ def _fit_impl( model_provided=True, model_connected=self.lightning_module is not None, ) - self._run(model, ckpt_path=ckpt_path) + self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.training = False @@ -621,6 +630,7 @@ def validate( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -676,7 +686,7 @@ def validate( self.state.status = TrainerStatus.RUNNING self.validating = True return call._call_and_handle_interrupt( - self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._validate_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule ) def _validate_impl( @@ -684,6 +694,7 @@ def _validate_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -717,7 +728,7 @@ def _validate_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the validation results results = convert_tensors_to_scalars(results) @@ -731,6 +742,7 @@ def test( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: @@ -787,7 +799,7 @@ def test( self.state.status = TrainerStatus.RUNNING self.testing = True return call._call_and_handle_interrupt( - self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._test_impl, model, dataloaders, ckpt_path, weights_only, verbose, datamodule ) def _test_impl( @@ -795,6 +807,7 @@ def _test_impl( model: Optional["pl.LightningModule"] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: @@ -828,7 +841,7 @@ def _test_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = 
self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the test results results = convert_tensors_to_scalars(results) @@ -844,6 +857,7 @@ def predict( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> Optional[_PREDICT_OUTPUT]: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. @@ -899,7 +913,7 @@ def predict( self.state.status = TrainerStatus.RUNNING self.predicting = True return call._call_and_handle_interrupt( - self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path + self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only ) def _predict_impl( @@ -909,6 +923,7 @@ def _predict_impl( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: bool = False, ) -> Optional[_PREDICT_OUTPUT]: # -------------------- # SETUP HOOK @@ -939,7 +954,7 @@ def _predict_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.predicting = False @@ -947,7 +962,7 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None + self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None, weights_only: bool = False ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -992,7 +1007,7 @@ def _run( # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") - self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, weights_only) # reset logger connector self._logger_connector.reset_results() diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index ec2aabf559dc7..420f711678808 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -53,9 +53,9 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup - return super().load_checkpoint(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) model = BoringModel() trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 006a123356c98..a5ad77cf25c1a 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ 
b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -18,6 +18,7 @@ import pytest import torch +from packaging.version import Version import lightning.pytorch as pl from lightning.pytorch import Callback, Trainer @@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) + if pl_version == "local": + pl_version = pl.__version__ + + weights_only = Version(pl_version) >= Version("1.5.0") + + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) trainer = Trainer(default_root_dir=tmp_path) dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) @@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) + if pl_version == "local": + pl_version = pl.__version__ + weights_only = not Version(pl_version) < Version("1.5.0") + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt, weights_only=False) + _ = torch.load(path_ckpt, weights_only=weights_only) with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) @@ -94,9 +105,14 @@ def load_model(): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) -def test_resume_legacy_checkpoints(tmp_path, pl_version: str): +def test_resume_legacy_checkpoints(monkeypatch, tmp_path, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) with patch("sys.path", [PATH_LEGACY] + sys.path): + if pl_version == "local": + pl_version = pl.__version__ + if Version(pl_version) < Version("1.5.0"): + monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 575bcadadc404..d0c72721ce1be 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -19,6 +19,7 @@ from argparse import Namespace from dataclasses import dataclass, field from enum import Enum +from typing import Optional from unittest import mock import cloudpickle @@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): +def _run_standard_hparams_test( + tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None +): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,19 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + 
raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -175,8 +179,10 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + # weights_only=False as omegaconf.DictConfig is not an allowed global by default + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) + assert isinstance(obj2.hparams, Container) # config specific tests @@ -367,13 +373,17 @@ class DictConfSubClassBoringModel: ... BoringModelWithMixinAndInit, ], ) -def test_collect_init_arguments(tmp_path, cls): +def test_collect_init_arguments(tmp_path, cls: BoringModel): """Test that the model automatically saves the arguments passed into the constructor.""" extra_args = {} + weights_only = True + if cls is AggSubClassBoringModel: extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) + weights_only = False elif cls is DictConfSubClassBoringModel: extra_args.update(dict_conf=OmegaConf.create({"my_param": "anything"})) + weights_only = False model = cls(**extra_args) assert model.hparams.batch_size == 64 @@ -392,12 +402,12 @@ def test_collect_init_arguments(tmp_path, cls): raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179 # verify that model loads correctly - model = cls.load_from_checkpoint(raw_checkpoint_path) + model = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert model.hparams.batch_size == 179 if isinstance(model, AggSubClassBoringModel): @@ -408,7 +418,7 @@ def test_collect_init_arguments(tmp_path, cls): assert model.hparams.dict_conf["my_param"] == "anything" # verify that we can overwrite whatever we want - model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99, weights_only=weights_only) assert model.hparams.batch_size == 99 @@ -781,7 +791,7 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None): logger=False, ) trainer.fit(model) - _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path, weights_only=False) @pytest.mark.parametrize("ignore", ["arg2", ("arg2", "arg3")]) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py 
b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index f7a76079cfca2..c6be86baf3bc6 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO): def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True + ) -> dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: @@ -67,7 +69,7 @@ def test_checkpoint_plugin_called(tmp_path): assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) - checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt")) + checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=False) checkpoint_plugin.reset_mock() ck = ModelCheckpoint(dirpath=tmp_path, save_last=True) @@ -95,7 +97,7 @@ def test_checkpoint_plugin_called(tmp_path): trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_once() - checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt")) + checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=False) @pytest.mark.flaky(reruns=3) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 248852f4cf1f3..c25f75658341b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1222,7 +1222,7 @@ def test_lightning_cli_model_short_arguments(): ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) - run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), @@ -1250,7 +1250,7 @@ def test_lightning_cli_datamodule_short_arguments(): ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), @@ -1271,7 +1271,7 @@ def test_lightning_cli_datamodule_short_arguments(): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), @@ -1447,7 +1447,7 @@ def test_lightning_cli_config_with_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False) assert cli.trainer.limit_test_batches == 1 @@ -1463,7 +1463,9 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - 
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1476,7 +1478,7 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False) assert cli.trainer.limit_val_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1494,7 +1496,9 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 with ( @@ -1503,7 +1507,7 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False) assert cli.trainer.limit_val_batches == 1 @@ -1515,7 +1519,7 @@ def test_lightning_cli_config_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False) assert cli.trainer.limit_test_batches == 1 @@ -1528,7 +1532,9 @@ def test_lightning_cli_config_before_and_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=False + ) assert cli.trainer.limit_test_batches == 1 assert cli.trainer.fast_dev_run == 1
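Taken together with the updated abstract ``CheckpointIO.load_checkpoint`` signature at the top of the diff, third-party IO plugins are expected to accept the new keyword. A minimal sketch of a compliant plugin (the class name and the ``None``-handling policy are illustrative, not the built-in ``TorchCheckpointIO``):

```python
import os
from typing import Any, Optional

import torch

from lightning.pytorch.plugins.io import CheckpointIO


class SimpleCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint: dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def load_checkpoint(
        self, path: str, map_location: Optional[Any] = None, weights_only: Optional[bool] = None
    ) -> dict[str, Any]:
        # Treat None as "safe by default" in this toy plugin; the library resolves its own
        # default from the installed torch version (torch>=2.6), per the CHANGELOG entries above.
        resolved = True if weights_only is None else weights_only
        return torch.load(path, map_location=map_location, weights_only=resolved)

    def remove_checkpoint(self, path: str) -> None:
        os.remove(path)
```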