Merged

38 commits
2fc8de3 simple plug (Borda, Feb 27, 2025)
60e3eb8 2.5.1rc0 (Borda, Feb 27, 2025)
b6a490d model_registry (Borda, Mar 6, 2025)
eabb3dc if ckpt_path is True (Borda, Mar 7, 2025)
557aa9c 2.5.1rc1 (Borda, Mar 8, 2025)
6955b50 registry:version (Borda, Mar 8, 2025)
dedff89 _parse_registry_model_version (Borda, Mar 8, 2025)
e04c934 2.5.1rc2 (Borda, Mar 9, 2025)
e577fe7 fix parse version (Borda, Mar 12, 2025)
8c0d809 LightningModelCheckpoint (Borda, Mar 12, 2025)
87bc02f import lightning/pytorch_lightning (Borda, Mar 12, 2025)
6079bde print (Borda, Mar 14, 2025)
b5c82ab print (Borda, Mar 14, 2025)
7a98958 self.trainer.strategy.barrier (Borda, Mar 14, 2025)
facd318 local_model_dir (Borda, Mar 14, 2025)
c5caed8 cleaning (Borda, Mar 15, 2025)
05ff4d1 print (Borda, Mar 15, 2025)
72c9af4 Revert "cleaning" (Borda, Mar 15, 2025)
0301490 rank_zero_info (Borda, Mar 15, 2025)
380c3b3 print (Borda, Mar 15, 2025)
5159123 print (Borda, Mar 15, 2025)
6ed4826 print (Borda, Mar 15, 2025)
ddc01d9 #print (Borda, Mar 15, 2025)
926c61b refactor (Borda, Mar 16, 2025)
cdfd320 self.local_rank == 0 (Borda, Mar 16, 2025)
468453d trainer.strategy.barrier("download_model_from_registry") (Borda, Mar 16, 2025)
41b9bcc Merge branch 'master' into add/litmodels (Borda, Mar 16, 2025)
c5896a6 docs (Borda, Mar 16, 2025)
4133bba _fit_impl (Borda, Mar 16, 2025)
1c149bd predict validate tests (Borda, Mar 16, 2025)
3e3a588 _download_model_registry (Borda, Mar 17, 2025)
5e19d07 Merge branch 'master' into add/litmodels (Borda, Mar 18, 2025)
99c7042 doctest (Borda, Mar 18, 2025)
4d846be doctest (Borda, Mar 18, 2025)
81bd9ba __doctest_skip__ (Borda, Mar 18, 2025)
b223878 typing (Borda, Mar 18, 2025)
e90ae44 typing (Borda, Mar 18, 2025)
0c17cd7 typing (Borda, Mar 18, 2025)
2 changes: 1 addition & 1 deletion Makefile
@@ -53,7 +53,7 @@ docs-fabric: clean sphinx-theme
cd docs/source-fabric && $(MAKE) html --jobs $(nproc)

docs-pytorch: clean sphinx-theme
pip install -e .[all] --quiet -r requirements/pytorch/docs.txt -r _notebooks/.actions/requires.txt
pip install -e .[all] --quiet -r requirements/pytorch/docs.txt
cd docs/source-pytorch && $(MAKE) html --jobs $(nproc)

update:
2 changes: 1 addition & 1 deletion requirements/pytorch/docs.txt
@@ -4,4 +4,4 @@ nbformat # used for generate empty notebook
ipython[notebook] <8.7.0
setuptools<58.0 # workaround for `error in ipython setup command: use_2to3 is invalid.`

-r ../../_notebooks/.actions/requires.txt
#-r ../../_notebooks/.actions/requires.txt
1 change: 0 additions & 1 deletion src/lightning/pytorch/trainer/__init__.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""

from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch.trainer.trainer import Trainer
23 changes: 21 additions & 2 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
from collections.abc import Sequence
from datetime import timedelta
from typing import Optional, Union

from lightning_utilities import module_available

import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import (
@@ -91,7 +93,24 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
" but found `ModelCheckpoint` in callbacks list."
)
elif enable_checkpointing:
self.trainer.callbacks.append(ModelCheckpoint())
if module_available("litmodels") and self.trainer._model_registry:
trainer_source = inspect.getmodule(self.trainer)
if trainer_source is None or not isinstance(trainer_source.__package__, str):
raise RuntimeError("Unable to determine the source of the trainer.")
# this needs to be imported based on the actual package: lightning or pytorch_lightning
if "pytorch_lightning" in trainer_source.__package__:
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
else:
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint

model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
else:
rank_zero_info(
"You are using the plain ModelCheckpoint callback."
" Consider using LitModelCheckpoint which with seamless uploading to Model registry."
)
model_checkpoint = ModelCheckpoint()
self.trainer.callbacks.append(model_checkpoint)

def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
if not enable_model_summary:
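For illustration, a minimal sketch of what this callback selection means from the user side; "my-model" is a hypothetical registry name, and the concrete callback class depends on whether litmodels is installed:

import lightning.pytorch as pl

# With litmodels installed and model_registry set, the connector appends a
# LitModelCheckpoint bound to "my-model"; otherwise it falls back to the
# plain ModelCheckpoint and logs the hint above.
trainer = pl.Trainer(model_registry="my-model")
print([type(cb).__name__ for cb in trainer.callbacks])
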
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -19,6 +19,7 @@
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem
from lightning_utilities import module_available
from torch import Tensor

import lightning.pytorch as pl
@@ -33,6 +34,10 @@
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.model_registry import (
_is_registry,
find_model_local_ckpt_path,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)
@@ -48,8 +53,7 @@ def __init__(self, trainer: "pl.Trainer") -> None:

@property
def _hpc_resume_path(self) -> Optional[str]:
dir_path_hpc = self.trainer.default_root_dir
dir_path_hpc = str(dir_path_hpc)
dir_path_hpc = str(self.trainer.default_root_dir)
fs, path = url_to_fs(dir_path_hpc)
if not _is_dir(fs, path):
return None
@@ -194,10 +198,17 @@ def _parse_ckpt_path(
if not self._hpc_resume_path:
raise ValueError(
f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
)
ckpt_path = self._hpc_resume_path

elif _is_registry(ckpt_path) and module_available("litmodels"):
ckpt_path = find_model_local_ckpt_path(
ckpt_path,
default_model_registry=self.trainer._model_registry,
default_root_dir=self.trainer.default_root_dir,
)

if not ckpt_path:
raise ValueError(
f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
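A hedged sketch of how this new branch behaves when resuming; it assumes litmodels is installed, a default model was set on the Trainer, and the referenced registry entry exists (BoringModel is just a stand-in module):

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel  # stand-in model

trainer = pl.Trainer(model_registry="my-model")  # "my-model" is hypothetical
# _is_registry("registry:version:v2") is True, so the checkpoint is first
# downloaded from the registry and find_model_local_ckpt_path maps it to a
# local file path before the usual loading logic runs.
trainer.fit(BoringModel(), ckpt_path="registry:version:v2")
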
43 changes: 36 additions & 7 deletions src/lightning/pytorch/trainer/trainer.py
@@ -30,6 +30,7 @@
from weakref import proxy

import torch
from lightning_utilities import module_available
from torch.optim import Optimizer

import lightning.pytorch as pl
@@ -70,6 +71,7 @@
from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.seed import isolate_rng
from lightning.pytorch.utilities.types import (
@@ -129,6 +131,7 @@ def __init__(
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[_PATH] = None,
enable_autolog_hparams: bool = True,
model_registry: Optional[str] = None,
) -> None:
r"""Customize every aspect of training via flags.

@@ -294,6 +297,8 @@
enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
Default: ``True``.

model_registry: The name of the model to be uploaded to the model hub.

Raises:
TypeError:
If ``gradient_clip_val`` is not an int or float.
@@ -308,6 +313,9 @@
if default_root_dir is not None:
default_root_dir = os.fspath(default_root_dir)

# remove version if accidentally passed
self._model_registry = model_registry.split(":")[0] if model_registry else None

self.barebones = barebones
if barebones:
# opt-outs
@@ -525,7 +533,20 @@ def fit(
the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
keywords ``"last"``, ``"hpc"`` and ``"registry"``.
Otherwise, if there is no checkpoint file at the path, an exception is raised.

- best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
- last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
- registry: the model will be downloaded from the Lightning Model Registry with the following notations:

- ``'registry'``: uses the latest/default version of the default model set
with ``Trainer(..., model_registry="my-model")``
- ``'registry:model-name'``: uses the latest/default version of this model `model-name`
- ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
- ``'registry:version:v2'``: uses the default model set
with ``Trainer(..., model_registry="my-model")`` and version 'v2'


Raises:
TypeError:
Expand Down Expand Up @@ -573,6 +594,8 @@ def _fit_impl(
)

assert self.state.fn is not None
if _is_registry(ckpt_path) and module_available("litmodels"):
download_model_from_registry(ckpt_path, self)
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn,
ckpt_path,
@@ -602,8 +625,8 @@ def validate(
Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
If ``None`` and the model instance was passed, use the current weights.
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
to validate. If ``None`` and the model instance was passed, use the current weights.
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
if a checkpoint callback is configured.

@@ -681,6 +704,8 @@ def _validate_impl(
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
if _is_registry(ckpt_path) and module_available("litmodels"):
download_model_from_registry(ckpt_path, self)
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
@@ -711,8 +736,8 @@ def test(
Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
If ``None`` and the model instance was passed, use the current weights.
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
to test. If ``None`` and the model instance was passed, use the current weights.
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
if a checkpoint callback is configured.

@@ -790,6 +815,8 @@ def _test_impl(
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
if _is_registry(ckpt_path) and module_available("litmodels"):
download_model_from_registry(ckpt_path, self)
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
@@ -826,8 +853,8 @@ def predict(
return_predictions: Whether to return predictions.
``True`` by default except when an accelerator that spawns processes is used (not supported).

ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict.
If ``None`` and the model instance was passed, use the current weights.
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
to predict. If ``None`` and the model instance was passed, use the current weights.
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
if a checkpoint callback is configured.

@@ -899,6 +926,8 @@ def _predict_impl(
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
if _is_registry(ckpt_path) and module_available("litmodels"):
download_model_from_registry(ckpt_path, self)
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
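Taken together, the diff threads the registry through fit, validate, test, and predict. A hedged end-to-end sketch, assuming litmodels is installed, the session is authenticated against the Lightning hub, and the registry names and versions below actually exist:

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel  # stand-in model

# Train and upload checkpoints under "my-model" (hypothetical registry name);
# with litmodels installed, LitModelCheckpoint handles the upload.
trainer = pl.Trainer(max_epochs=2, model_registry="my-model")
trainer.fit(BoringModel())

# Evaluate from the registry later, using the notations documented above.
trainer.validate(BoringModel(), ckpt_path="registry")                  # default model, latest version
trainer.test(BoringModel(), ckpt_path="registry:my-model:version:v2")  # explicit model and version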