diff --git a/Makefile b/Makefile index 426c18042994c..fadce73be389f 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,7 @@ docs-fabric: clean sphinx-theme cd docs/source-fabric && $(MAKE) html --jobs $(nproc) docs-pytorch: clean sphinx-theme - pip install -e .[all] --quiet -r requirements/pytorch/docs.txt -r _notebooks/.actions/requires.txt + pip install -e .[all] --quiet -r requirements/pytorch/docs.txt cd docs/source-pytorch && $(MAKE) html --jobs $(nproc) update: diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt index 21287196933ea..b3725391271ae 100644 --- a/requirements/pytorch/docs.txt +++ b/requirements/pytorch/docs.txt @@ -4,4 +4,4 @@ nbformat # used for generate empty notebook ipython[notebook] <8.7.0 setuptools<58.0 # workaround for `error in ipython setup command: use_2to3 is invalid.` --r ../../_notebooks/.actions/requires.txt +#-r ../../_notebooks/.actions/requires.txt diff --git a/src/lightning/pytorch/trainer/__init__.py b/src/lightning/pytorch/trainer/__init__.py index cbed5dd4f1f20..f2e1b963306a1 100644 --- a/src/lightning/pytorch/trainer/__init__.py +++ b/src/lightning/pytorch/trainer/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""""" from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch.trainer.trainer import Trainer diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index a60f907d9361b..3f107bd9a124a 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import inspect import logging import os from collections.abc import Sequence from datetime import timedelta from typing import Optional, Union +from lightning_utilities import module_available + import lightning.pytorch as pl from lightning.fabric.utilities.registry import _load_external_callbacks from lightning.pytorch.callbacks import ( @@ -91,7 +93,24 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None: " but found `ModelCheckpoint` in callbacks list." ) elif enable_checkpointing: - self.trainer.callbacks.append(ModelCheckpoint()) + if module_available("litmodels") and self.trainer._model_registry: + trainer_source = inspect.getmodule(self.trainer) + if trainer_source is None or not isinstance(trainer_source.__package__, str): + raise RuntimeError("Unable to determine the source of the trainer.") + # this needs to be imported based on the actual package lightning/pytorch_lightning + if "pytorch_lightning" in trainer_source.__package__: + from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint + else: + from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint + + model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry) + else: + rank_zero_info( + "You are using the plain ModelCheckpoint callback." + " Consider using LitModelCheckpoint which offers seamless uploading to the Model registry." 
+ ) + model_checkpoint = ModelCheckpoint() + self.trainer.callbacks.append(model_checkpoint) def _configure_model_summary_callback(self, enable_model_summary: bool) -> None: if not enable_model_summary: diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index 71cc5a14686be..7f97a2f54bf19 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -19,6 +19,7 @@ import torch from fsspec.core import url_to_fs from fsspec.implementations.local import LocalFileSystem +from lightning_utilities import module_available from torch import Tensor import lightning.pytorch as pl @@ -33,6 +34,10 @@ from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning.pytorch.utilities.migration import pl_legacy_patch from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint +from lightning.pytorch.utilities.model_registry import ( + _is_registry, + find_model_local_ckpt_path, +) from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn log = logging.getLogger(__name__) @@ -48,8 +53,7 @@ def __init__(self, trainer: "pl.Trainer") -> None: @property def _hpc_resume_path(self) -> Optional[str]: - dir_path_hpc = self.trainer.default_root_dir - dir_path_hpc = str(dir_path_hpc) + dir_path_hpc = str(self.trainer.default_root_dir) fs, path = url_to_fs(dir_path_hpc) if not _is_dir(fs, path): return None @@ -194,10 +198,17 @@ def _parse_ckpt_path( if not self._hpc_resume_path: raise ValueError( f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.' 
- " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" + f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" ) ckpt_path = self._hpc_resume_path + elif _is_registry(ckpt_path) and module_available("litmodels"): + ckpt_path = find_model_local_ckpt_path( + ckpt_path, + default_model_registry=self.trainer._model_registry, + default_root_dir=self.trainer.default_root_dir, + ) + if not ckpt_path: raise ValueError( f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 6c3ff4612c68a..8e4e2de97fd6a 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -30,6 +30,7 @@ from weakref import proxy import torch +from lightning_utilities import module_available from torch.optim import Optimizer import lightning.pytorch as pl @@ -70,6 +71,7 @@ from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.seed import isolate_rng from lightning.pytorch.utilities.types import ( @@ -129,6 +131,7 @@ def __init__( reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, enable_autolog_hparams: bool = True, + model_registry: Optional[str] = None, ) -> None: r"""Customize every aspect of training via flags. @@ -294,6 +297,8 @@ def __init__( enable_autolog_hparams: Whether to log hyperparameters at the start of a run. Default: ``True``. + model_registry: The name of the model being uploaded to Model hub. + Raises: TypeError: If ``gradient_clip_val`` is not an int or float. 
@@ -308,6 +313,9 @@ def __init__( if default_root_dir is not None: default_root_dir = os.fspath(default_root_dir) + # remove version if accidentally passed + self._model_registry = model_registry.split(":")[0] if model_registry else None + self.barebones = barebones if barebones: # opt-outs @@ -525,7 +533,20 @@ def fit( the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook. - ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special + ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of three special - keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised. + keywords ``"last"``, ``"hpc"`` and ``"registry"``. + Otherwise, if there is no checkpoint file at the path, an exception is raised. + + - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded + - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded + - registry: the model will be downloaded from the Lightning Model Registry with the following notations: + + - ``'registry'``: uses the latest/default version of default model set + with ``Trainer(..., model_registry="my-model")`` + - ``'registry:model-name'``: uses the latest/default version of this model `model-name` + - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name` + - ``'registry:version:v2'``: uses the default model set + with ``Trainer(..., model_registry="my-model")`` and version 'v2' + Raises: TypeError: @@ -573,6 +594,8 @@ def _fit_impl( ) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, @@ -602,8 +625,8 @@ def validate( Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook. 
- ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate. - If ``None`` and the model instance was passed, use the current weights. + ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish + to validate. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. @@ -681,6 +704,8 @@ def _validate_impl( self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) @@ -711,8 +736,8 @@ def test( Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook. - ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test. - If ``None`` and the model instance was passed, use the current weights. + ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish + to test. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. 
@@ -790,6 +815,8 @@ def _test_impl( self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) @@ -826,8 +853,8 @@ def predict( return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). - ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict. - If ``None`` and the model instance was passed, use the current weights. + ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish + to predict. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. @@ -899,6 +926,8 @@ def _predict_impl( self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) diff --git a/src/lightning/pytorch/utilities/model_registry.py b/src/lightning/pytorch/utilities/model_registry.py new file mode 100644 index 0000000000000..a9ed495eb37d8 --- /dev/null +++ b/src/lightning/pytorch/utilities/model_registry.py @@ -0,0 +1,178 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import Optional + +from lightning_utilities import module_available + +import lightning.pytorch as pl +from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.types import _PATH + +# skip these tests on Windows as the path notation differs +if _IS_WINDOWS: + __doctest_skip__ = ["_determine_model_folder"] + + +def _is_registry(text: Optional[_PATH]) -> bool: + """Check if a string equals 'registry' or starts with 'registry:'. + + Args: + text: The string to check + + >>> _is_registry("registry") + True + >>> _is_registry("REGISTRY:model-name") + True + >>> _is_registry("something_registry") + False + >>> _is_registry("") + False + + """ + if not isinstance(text, str): + return False + + # Pattern matches exactly 'registry' or 'registry:' followed by any characters + pattern = r"^registry(:.*|$)" + return bool(re.match(pattern, text.lower())) + + +def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str]: + """Parse the model version from a registry path. 
+ + Args: + ckpt_path: The checkpoint path + + Returns: + string name and version of the model + + >>> _parse_registry_model_version("registry:model-name:version:1.0") + ('model-name', '1.0') + >>> _parse_registry_model_version("registry:model-name") + ('model-name', '') + >>> _parse_registry_model_version("registry:version:v2") + ('', 'v2') + + """ + if not ckpt_path or not _is_registry(ckpt_path): + raise ValueError(f"Invalid registry path: {ckpt_path}") + + # Split the path by ':' + parts = str(ckpt_path).lower().split(":") + # Default values + model_name, version = "", "" + + # Extract the model name and version based on the parts + if len(parts) >= 2 and parts[1] != "version": + model_name = parts[1] + if len(parts) == 3 and parts[1] == "version": + version = parts[2] + elif len(parts) == 4 and parts[2] == "version": + version = parts[3] + + return model_name, version + + +def _determine_model_name(ckpt_path: Optional[_PATH], default_model_registry: Optional[str]) -> str: + """Determine the model name from the checkpoint path. 
+ + Args: + ckpt_path: The checkpoint path + default_model_registry: The default model registry + + Returns: + string name of the model with optional version + + >>> _determine_model_name("registry:model-name:version:1.0", "default-model") + 'model-name:1.0' + >>> _determine_model_name("registry:model-name", "default-model") + 'model-name' + >>> _determine_model_name("registry:version:v2", "default-model") + 'default-model:v2' + + """ + # try to find model and version + model_name, model_version = _parse_registry_model_version(ckpt_path) + # omitted model name try to use the model registry from Trainer + if not model_name and default_model_registry: + model_name = default_model_registry + if not model_name: + raise ValueError(f"Invalid model registry: '{ckpt_path}'") + model_registry = model_name + model_registry += f":{model_version}" if model_version else "" + return model_registry + + +def _determine_model_folder(model_name: str, default_root_dir: str) -> str: + """Determine the local model folder based on the model registry. 
+ + Args: + model_name: The model name + default_root_dir: The default root directory + + Returns: + string path to the local model folder + + >>> _determine_model_folder("model-name", "/path/to/root") + '/path/to/root/model-name' + >>> _determine_model_folder("model-name:1.0", "/path/to/root") + '/path/to/root/model-name_1.0' + + """ + if not model_name: + raise ValueError(f"Invalid model registry: '{model_name}'") + # download the latest checkpoint from the model registry + model_name = model_name.replace("/", "_") + model_name = model_name.replace(":", "_") + local_model_dir = os.path.join(default_root_dir, model_name) + return local_model_dir + + +def find_model_local_ckpt_path( + ckpt_path: Optional[_PATH], default_model_registry: Optional[str], default_root_dir: str +) -> str: + """Find the local checkpoint path for a model.""" + model_registry = _determine_model_name(ckpt_path, default_model_registry) + local_model_dir = _determine_model_folder(model_registry, default_root_dir) + + # TODO: resolve if there are multiple checkpoints + folder_files = [fn for fn in os.listdir(local_model_dir) if fn.endswith(".ckpt")] + if not folder_files: + raise RuntimeError(f"No `.ckpt` files found in the downloaded model: {model_registry}") + # print(f"local RANK {self.trainer.local_rank}: using model files: {folder_files}") + return os.path.join(local_model_dir, folder_files[0]) + + +def download_model_from_registry(ckpt_path: Optional[_PATH], trainer: "pl.Trainer") -> None: + """Download a model from the Lightning Model Registry.""" + if trainer.local_rank == 0: + if not module_available("litmodels"): + raise ImportError( + "The `litmodels` package is not installed. Please install it with `pip install litmodels`." 
+ ) + + from litmodels import download_model + + model_registry = _determine_model_name(ckpt_path, trainer._model_registry) + local_model_dir = _determine_model_folder(model_registry, trainer.default_root_dir) + + # print(f"Rank {self.trainer.local_rank} downloads model checkpoint '{model_registry}'") + model_files = download_model(model_registry, download_dir=local_model_dir) + # print(f"Model checkpoint '{model_registry}' was downloaded to '{local_model_dir}'") + if not model_files: + raise RuntimeError(f"Download model failed - {model_registry}") + + trainer.strategy.barrier("download_model_from_registry") diff --git a/src/version.info b/src/version.info index 797b505d19610..ee06cd3353a48 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.5.0.post0 +2.5.1rc2