diff --git a/pyproject.toml b/pyproject.toml index e6d08411b0f35..f5e9803563e92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,7 @@ filterwarnings = [ # "error::DeprecationWarning", "error::FutureWarning", "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated + "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", ] xfail_strict = true junit_duration_report = "call" diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 0e1cc944a3492..307975e1619bd 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -19,9 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- - - +- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) - Set `_DeviceDtypeModuleMixin._device` from torch's default device function ([#21164](https://github.com/Lightning-AI/pytorch-lightning/pull/21164)) diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 3a33dac3335d1..db7578d9ca8c0 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -47,13 +47,20 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, map_location: Optional[Any] = None, weights_only: Optional[bool] = None + ) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 90a5f62ba7413..c52ad6913e1e2 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -59,7 +59,10 @@ def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage + self, + path: _PATH, + map_location: Optional[Callable] = lambda storage, loc: storage, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. @@ -67,6 +70,11 @@ def load_checkpoint( path: Path to checkpoint map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. 
If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. Returns: The loaded checkpoint. @@ -80,7 +88,7 @@ def load_checkpoint( if not fs.exists(path): raise FileNotFoundError(f"Checkpoint file not found: {path}") - return pl_load(path, map_location=map_location) + return pl_load(path, map_location=map_location, weights_only=weights_only) @override def remove_checkpoint(self, path: _PATH) -> None: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index fe72db20e2b85..209ddaaf57548 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -458,6 +458,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. @@ -483,7 +484,7 @@ def load_checkpoint( # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from # a consolidated checkpoint path = self.broadcast(path) - return super().load_checkpoint(path=path, state=state, strict=strict) + return super().load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) if not state: raise ValueError( diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index baaee74af0ec9..f42ade7484395 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -516,6 +516,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -586,7 +587,7 @@ def load_checkpoint( optim.load_state_dict(flattened_osd) # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) for key in requested_metadata_keys: diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 0d49ddf91a0bc..677584668975e 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -275,6 +275,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: @@ -295,7 +296,7 @@ def load_checkpoint( f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}." 
) - return _load_checkpoint(path=path, state=state, strict=strict) + return _load_checkpoint(path=path, state=state, strict=strict, weights_only=weights_only) def _setup_distributed(self) -> None: reset_seed() @@ -411,6 +412,7 @@ def _load_checkpoint( state: dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, @@ -449,7 +451,7 @@ def _load_checkpoint( set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options) # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) for key in requested_metadata_keys: @@ -461,7 +463,7 @@ def _load_checkpoint( return metadata if _is_full_checkpoint(path): - checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False) + checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=weights_only) _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict) state_dict_options = StateDictOptions( diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 0b2b373acb7bc..b368f626c3b11 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -310,6 +310,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. @@ -330,7 +331,7 @@ def load_checkpoint( """ torch.cuda.empty_cache() - checkpoint = self.checkpoint_io.load_checkpoint(path) + checkpoint = self.checkpoint_io.load_checkpoint(path, weights_only=weights_only) if not state: return checkpoint diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 3fa9e40f4b4bd..51b528eff26ff 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -516,6 +516,7 @@ def load_checkpoint( path: _PATH, state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, + weights_only: Optional[bool] = None, ) -> dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. 
@@ -608,7 +609,7 @@ def load_checkpoint( ) if "model" not in state or not isinstance(model := state["model"], torch.nn.Module): raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.") - full_ckpt = torch.load(path) + full_ckpt = torch.load(path, weights_only=weights_only) model.load_state_dict(full_ckpt.pop("model"), strict=strict) return full_ckpt diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 637dfcd9b1671..54b18fb6ce3b0 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -17,7 +17,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Union +from typing import IO, Any, Optional, Union import fsspec import fsspec.utils @@ -34,13 +34,18 @@ def _load( path_or_url: Union[IO, _PATH], map_location: _MAP_LOCATION_TYPE = None, - weights_only: bool = False, + weights_only: Optional[bool] = None, ) -> Any: """Loads a checkpoint. Args: path_or_url: Path or URL of the checkpoint. map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other primitive + types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. """ if not isinstance(path_or_url, (str, Path)): @@ -51,6 +56,13 @@ def _load( weights_only=weights_only, ) if str(path_or_url).startswith("http"): + if weights_only is None: + weights_only = False + log.debug( + f"Defaulting to `weights_only=False` for remote checkpoint: {path_or_url}." + f" If loading a checkpoint from an untrusted source, we recommend using `weights_only=True`." + ) + return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] @@ -70,7 +82,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: dict[str, Any], filepath: _PATH) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
Args: diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 70239baac0e6d..5655f2674638e 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -35,6 +35,7 @@ _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1bba5e4ca0da7..e4936e8aca0a2 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072)) + + - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896)) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index ff84c2fd8b199..07ec02ef87bd8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -177,6 +177,7 @@ def load_from_checkpoint( checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the @@ -206,6 +207,11 @@ def load_from_checkpoint( If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace` and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat ``hparams`` as :class:`~dict`. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved hyperparameter values. 
@@ -242,6 +248,7 @@ def load_from_checkpoint( map_location=map_location, hparams_file=hparams_file, strict=None, + weights_only=weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 85f631ee40f75..37b07f025f8e9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1690,6 +1690,7 @@ def load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Self: r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments @@ -1723,6 +1724,11 @@ def load_from_checkpoint( strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. Defaults to ``True`` unless ``LightningModule.strict_loading`` is set, in which case it defaults to the value of ``LightningModule.strict_loading``. + weights_only: If ``True``, restricts loading to ``state_dicts`` of plain ``torch.Tensor`` and other + primitive types. If loading a checkpoint from a trusted source that contains an ``nn.Module``, use + ``weights_only=False``. If loading checkpoint from an untrusted source, we recommend using + ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. @@ -1778,6 +1784,7 @@ def load_from_checkpoint( map_location, hparams_file, strict, + weights_only, **kwargs, ) return cast(Self, loaded) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 21fd3912f7849..391e9dd5d0f25 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -56,11 +56,13 @@ def _load_from_checkpoint( map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: Optional[bool] = None, + weights_only: Optional[bool] = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: map_location = map_location or _default_map_location + with pl_legacy_patch(): - checkpoint = pl_load(checkpoint_path, map_location=map_location) + checkpoint = pl_load(checkpoint_path, map_location=map_location, weights_only=weights_only) # convert legacy checkpoints to the new format checkpoint = _pl_migrate_checkpoint( diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index c5253f77cdedb..369360590878d 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -659,12 +659,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op ) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing checkpoint_path = self.broadcast(checkpoint_path) - return super().load_checkpoint(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) _validate_checkpoint_directory(checkpoint_path) diff --git a/src/lightning/pytorch/strategies/fsdp.py 
b/src/lightning/pytorch/strategies/fsdp.py index 3fbd0f9cd5f0a..9706c8a64e61b 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -583,7 +583,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) @@ -624,7 +624,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: optim.load_state_dict(flattened_osd) # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) + metadata = torch.load(path / _METADATA_FILENAME, weights_only=weights_only) return metadata if _is_full_checkpoint(path): diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index e0286dbe2e0e6..f3165a08e6bdd 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -329,7 +329,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { @@ -342,6 +342,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: state=state, strict=self.lightning_module.strict_loading, optimizer_states_from_list=True, + weights_only=weights_only, ) def _setup_distributed(self) -> None: diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 16b16a4927513..0a00cb28af15e 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -363,9 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH, weights_only: Optional[bool] = None) -> dict[str, Any]: torch.cuda.empty_cache() - return self.checkpoint_io.load_checkpoint(checkpoint_path) + return self.checkpoint_io.load_checkpoint(checkpoint_path, weights_only=weights_only) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: assert self.lightning_module is not None diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 7f97a2f54bf19..ae5038b2022d2 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -64,7 +64,7 @@ def _hpc_resume_path(self) -> Optional[str]: return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt" return None - def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: + def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: """Attempts to pre-load the checkpoint file to memory, 
with the source path determined in this priority: 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. @@ -80,7 +80,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): - loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only=weights_only) self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( @@ -230,7 +230,7 @@ def resume_end(self) -> None: # wait for all to catch up self.trainer.strategy.barrier("_CheckpointConnector.resume_end") - def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: + def restore(self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore, in this priority: @@ -244,7 +244,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ - self.resume_start(checkpoint_path) + self.resume_start(checkpoint_path, weights_only=weights_only) # restore module states self.restore_datamodule() @@ -403,18 +403,22 @@ def restore_lr_schedulers(self) -> None: for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) - def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: + def _restore_modules_and_callbacks( + self, checkpoint_path: Optional[_PATH] = None, weights_only: Optional[bool] = None + ) -> None: # restore modules after setup - self.resume_start(checkpoint_path) + self.resume_start(checkpoint_path, weights_only=weights_only) self.restore_model() self.restore_datamodule() self.restore_callbacks() - def dump_checkpoint(self, weights_only: bool = False) -> dict: + def dump_checkpoint(self, weights_only: Optional[bool] = None) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: - weights_only: saving model weights only + weights_only: If True, only saves model and loops state_dict objects. If False, + additionally saves callbacks, optimizers, schedulers, and precision plugin states. + Return: structured dictionary: { 'epoch': training epoch @@ -446,6 +450,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "loops": self._get_loops_state_dict(), } + if weights_only is None: + weights_only = False + log.info("`weights_only` was not set, defaulting to `False`.") + if not weights_only: # dump callbacks checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 5768c507e2e3f..f2f59e396ab23 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -526,6 +526,7 @@ def fit( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> None: r"""Runs the full optimization routine. @@ -556,6 +557,12 @@ def fit( - ``'registry:version:v2'``: uses the default model set with ``Trainer(..., model_registry="my-model")`` and version 'v2' + weights_only: Defaults to ``None``. 
If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + Raises: TypeError: If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than @@ -573,7 +580,14 @@ def fit( self.training = True self.should_stop = False call._call_and_handle_interrupt( - self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + self, + self._fit_impl, + model, + train_dataloaders, + val_dataloaders, + datamodule, + ckpt_path, + weights_only, ) def _fit_impl( @@ -583,6 +597,7 @@ def _fit_impl( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -610,7 +625,7 @@ def _fit_impl( model_provided=True, model_connected=self.lightning_module is not None, ) - self._run(model, ckpt_path=ckpt_path) + self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.training = False @@ -623,6 +638,7 @@ def validate( ckpt_path: Optional[_PATH] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the validation set. @@ -643,6 +659,12 @@ def validate( datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + For more information about multiple dataloaders, see this :ref:`section `. 
Returns: @@ -676,7 +698,7 @@ def validate( self.state.status = TrainerStatus.RUNNING self.validating = True return call._call_and_handle_interrupt( - self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only ) def _validate_impl( @@ -686,6 +708,7 @@ def _validate_impl( ckpt_path: Optional[_PATH] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK @@ -717,7 +740,7 @@ def _validate_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the validation results results = convert_tensors_to_scalars(results) @@ -733,6 +756,7 @@ def test( ckpt_path: Optional[_PATH] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. @@ -754,6 +778,12 @@ def test( datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + For more information about multiple dataloaders, see this :ref:`section `. Returns: @@ -787,7 +817,7 @@ def test( self.state.status = TrainerStatus.RUNNING self.testing = True return call._call_and_handle_interrupt( - self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule + self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule, weights_only ) def _test_impl( @@ -797,6 +827,7 @@ def _test_impl( ckpt_path: Optional[_PATH] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK @@ -828,7 +859,7 @@ def _test_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) # remove the tensors from the test results results = convert_tensors_to_scalars(results) @@ -844,6 +875,7 @@ def predict( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. 
@@ -866,6 +898,12 @@ def predict( Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. + weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain + ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains + an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we + recommend using ``weights_only=True``. For more information, please refer to the + `PyTorch Developer Notes on Serialization Semantics `_. + For more information about multiple dataloaders, see this :ref:`section `. Returns: @@ -899,7 +937,14 @@ def predict( self.state.status = TrainerStatus.RUNNING self.predicting = True return call._call_and_handle_interrupt( - self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path + self, + self._predict_impl, + model, + dataloaders, + datamodule, + return_predictions, + ckpt_path, + weights_only, ) def _predict_impl( @@ -909,6 +954,7 @@ def _predict_impl( datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: # -------------------- # SETUP HOOK @@ -939,7 +985,7 @@ def _predict_impl( ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path, weights_only=weights_only) assert self.state.stopped self.predicting = False @@ -947,7 +993,10 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None + self, + model: "pl.LightningModule", + ckpt_path: Optional[_PATH] = None, + weights_only: Optional[bool] = None, ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( @@ -992,7 +1041,7 @@ def _run( # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") - self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path, weights_only) # reset logger connector self._logger_connector.reset_results() @@ -1386,7 +1435,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: self._checkpoint_connector._user_managed = bool(ckpt_path) def save_checkpoint( - self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None + self, filepath: _PATH, weights_only: Optional[bool] = None, storage_options: Optional[Any] = None ) -> None: r"""Runs routine to create a checkpoint. 
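Before the test changes below, a minimal usage sketch (not part of the patch) of the keyword this PR threads from the `Trainer` entry points and `load_from_checkpoint` down to `torch.load`. The model, file name, and `Trainer` flags are placeholders chosen for illustration; only the `weights_only` argument itself comes from the diff.

```python
# Sketch only: exercising the new `weights_only` keyword end to end.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, logger=False)
trainer.fit(model)
trainer.save_checkpoint("boring.ckpt")

# A checkpoint we just wrote ourselves, i.e. a trusted source: full unpickling is acceptable.
trainer.test(model, ckpt_path="boring.ckpt", weights_only=False)

# For checkpoints from untrusted sources, restrict unpickling to tensors and primitive types
# (this can fail if the file stores richer objects, e.g. omegaconf containers in hparams).
trainer.validate(model, ckpt_path="boring.ckpt", weights_only=True)

# None (the default) defers to torch.load, whose own default switched to True in torch 2.6.
restored = BoringModel.load_from_checkpoint("boring.ckpt", weights_only=None)
```

Leaving the argument as `None` keeps the permissive behaviour on torch < 2.6 and the stricter default on torch >= 2.6, which is what the updated tests account for via `_TORCH_GREATER_EQUAL_2_6`.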
diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index ec2aabf559dc7..420f711678808 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -53,9 +53,9 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup - return super().load_checkpoint(checkpoint_path) + return super().load_checkpoint(checkpoint_path, weights_only) model = BoringModel() trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 34749087bfb97..fec835f199e0b 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -18,6 +18,7 @@ import pytest from lightning_utilities.test.warning import no_warning_call +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -132,7 +133,8 @@ def test_resume_callback_state_saved_by_type_stateful(tmp_path): callback = OldStatefulCallback(state=222) trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback]) - trainer.fit(model, ckpt_path=ckpt_path) + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None + trainer.fit(model, ckpt_path=ckpt_path, weights_only=weights_only) assert callback.state == 111 diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index abcd302149fcf..5786a3339a7fa 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -25,6 +25,7 @@ from torch.optim.swa_utils import SWALR from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset @@ -173,8 +174,9 @@ def train_with_swa( devices=devices, ) + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None with _backward_patch(trainer): - trainer.fit(model) + trainer.fit(model, weights_only=weights_only) # check the model is the expected assert trainer.lightning_module == model @@ -307,8 +309,9 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals } trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"): - trainer.fit(model) + trainer.fit(model, weights_only=weights_only) checkpoint_dir = Path(tmp_path) / "checkpoints" checkpoint_files = os.listdir(checkpoint_dir) @@ -318,7 +321,7 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) with 
_backward_patch(trainer): - trainer.fit(resume_model, ckpt_path=ckpt_path) + trainer.fit(resume_model, ckpt_path=ckpt_path, weights_only=weights_only) class CustomSchedulerModel(SwaTestModel): diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 006a123356c98..a5ad77cf25c1a 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -18,6 +18,7 @@ import pytest import torch +from packaging.version import Version import lightning.pytorch as pl from lightning.pytorch import Callback, Trainer @@ -45,7 +46,12 @@ def test_load_legacy_checkpoints(tmp_path, pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] - model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24) + if pl_version == "local": + pl_version = pl.__version__ + + weights_only = Version(pl_version) >= Version("1.5.0") + + model = ClassificationModel.load_from_checkpoint(path_ckpt, num_features=24, weights_only=weights_only) trainer = Trainer(default_root_dir=tmp_path) dm = ClassifDataModule(num_features=24, length=6000, batch_size=128, n_clusters_per_class=2, n_informative=8) res = trainer.test(model, datamodule=dm) @@ -73,13 +79,18 @@ def test_legacy_ckpt_threading(pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] + # legacy load utility added in 1.5.0 (see https://github.com/Lightning-AI/pytorch-lightning/pull/9166) + if pl_version == "local": + pl_version = pl.__version__ + weights_only = not Version(pl_version) < Version("1.5.0") + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt, weights_only=False) + _ = torch.load(path_ckpt, weights_only=weights_only) with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) @@ -94,9 +105,14 @@ def load_model(): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) -def test_resume_legacy_checkpoints(tmp_path, pl_version: str): +def test_resume_legacy_checkpoints(monkeypatch, tmp_path, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) with patch("sys.path", [PATH_LEGACY] + sys.path): + if pl_version == "local": + pl_version = pl.__version__ + if Version(pl_version) < Version("1.5.0"): + monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1") + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' path_ckpt = path_ckpts[-1] diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 92a07f0a3d05e..620d7d5eb896e 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -19,6 +19,7 @@ from argparse import Namespace from dataclasses import dataclass, field from enum import Enum +from typing import Optional from unittest import mock import cloudpickle @@ -29,6 +30,7 @@ from lightning_utilities.test.warning import no_warning_call from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from 
lightning.pytorch.core.datamodule import LightningDataModule @@ -94,7 +96,9 @@ def __init__(self, hparams, *my_args, **my_kwargs): # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False): +def _run_standard_hparams_test( + tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None +): """Tests for the existence of an arg 'test_arg=14'.""" obj = datamodule if issubclass(cls, LightningDataModule) else model @@ -108,19 +112,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert obj2.hparams.test_arg == 14 assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only) assert obj3.hparams.test_arg == 78 return raw_checkpoint_path @@ -175,8 +180,10 @@ def test_omega_conf_hparams(tmp_path, cls): assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule) - obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + # weights_only=False as omegaconf.DictConfig is not an allowed global by default + raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False) + assert isinstance(obj2.hparams, Container) # config specific tests @@ -367,13 +374,17 @@ class DictConfSubClassBoringModel: ... 
BoringModelWithMixinAndInit, ], ) -def test_collect_init_arguments(tmp_path, cls): +def test_collect_init_arguments(tmp_path, cls: BoringModel): """Test that the model automatically saves the arguments passed into the constructor.""" extra_args = {} + weights_only = True + if cls is AggSubClassBoringModel: extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) + weights_only = False elif cls is DictConfSubClassBoringModel: extra_args.update(dict_conf=OmegaConf.create({"my_param": "anything"})) + weights_only = False model = cls(**extra_args) assert model.hparams.batch_size == 64 @@ -392,12 +403,12 @@ def test_collect_init_arguments(tmp_path, cls): raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False) + raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only) assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179 # verify that model loads correctly - model = cls.load_from_checkpoint(raw_checkpoint_path) + model = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only) assert model.hparams.batch_size == 179 if isinstance(model, AggSubClassBoringModel): @@ -408,7 +419,7 @@ def test_collect_init_arguments(tmp_path, cls): assert model.hparams.dict_conf["my_param"] == "anything" # verify that we can overwrite whatever we want - model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99, weights_only=weights_only) assert model.hparams.batch_size == 99 @@ -738,8 +749,9 @@ def test_model_with_fsspec_as_parameter(tmp_path): trainer = Trainer( default_root_dir=tmp_path, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1 ) - trainer.fit(model) - trainer.test() + weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None + trainer.fit(model, weights_only=weights_only) + trainer.test(weights_only=weights_only) @pytest.mark.xfail( @@ -781,7 +793,7 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None): logger=False, ) trainer.fit(model) - _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path, weights_only=False) @pytest.mark.parametrize("ignore", ["arg2", ("arg2", "arg3")]) diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index f7a76079cfca2..0f66f215f6864 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -32,7 +32,9 @@ class CustomCheckpointIO(CheckpointIO): def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: + def load_checkpoint( + self, path: _PATH, storage_options: Optional[Any] = None, weights_only: bool = True + ) -> dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: @@ -67,7 +69,7 @@ def test_checkpoint_plugin_called(tmp_path): assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) - checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt")) + 
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=None) checkpoint_plugin.reset_mock() ck = ModelCheckpoint(dirpath=tmp_path, save_last=True) @@ -95,7 +97,7 @@ def test_checkpoint_plugin_called(tmp_path): trainer.test(model, ckpt_path=ck.last_model_path) checkpoint_plugin.load_checkpoint.assert_called_once() - checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt")) + checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=None) @pytest.mark.flaky(reruns=3) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index e79f1b78e02da..6ff4bee264a7b 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1258,7 +1258,7 @@ def test_lightning_cli_model_short_arguments(): ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) - run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), @@ -1286,7 +1286,7 @@ def test_lightning_cli_datamodule_short_arguments(): ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), @@ -1307,7 +1307,7 @@ def test_lightning_cli_datamodule_short_arguments(): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) - run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY, ANY) with ( mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), @@ -1483,7 +1483,7 @@ def test_lightning_cli_config_with_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=None) assert cli.trainer.limit_test_batches == 1 @@ -1499,7 +1499,7 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=None) assert cli.trainer.limit_test_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1512,7 +1512,7 @@ def test_lightning_cli_config_before_subcommand(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=None) assert cli.trainer.limit_val_batches == 1 save_config_callback = cli.trainer.callbacks[0] @@ -1530,7 +1530,7 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, 
ckpt_path="foobar", weights_only=None) assert cli.trainer.limit_test_batches == 1 with ( @@ -1539,7 +1539,7 @@ def test_lightning_cli_config_before_subcommand_two_configs(): ): cli = LightningCLI(BoringModel) - validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") + validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=None) assert cli.trainer.limit_val_batches == 1 @@ -1551,7 +1551,7 @@ def test_lightning_cli_config_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") + test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=None) assert cli.trainer.limit_test_batches == 1 @@ -1564,7 +1564,9 @@ def test_lightning_cli_config_before_and_after_subcommand(): ): cli = LightningCLI(BoringModel) - test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") + test_mock.assert_called_once_with( + cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=None + ) assert cli.trainer.limit_test_batches == 1 assert cli.trainer.fast_dev_run == 1