diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index d782721..a1cbc90 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -1,4 +1,4 @@ -from typing import Any, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any from lightning_sdk.lightning_cloud.login import Auth from lightning_utilities.core.rank_zero import rank_zero_only @@ -7,63 +7,78 @@ from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE if _LIGHTNING_AVAILABLE: - from lightning.pytorch import Trainer - from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint + from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint + + if TYPE_CHECKING: + from lightning.pytorch import Trainer + + if _PYTORCHLIGHTNING_AVAILABLE: - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint + from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint + + if TYPE_CHECKING: + from pytorch_lightning import Trainer + +# Base class to be inherited +class LitModelCheckpointMixin: + """Mixin class for LitModel checkpoint functionality.""" -# Type variable for the ModelCheckpoint class -ModelCheckpointType = TypeVar("ModelCheckpointType") + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: + """Initialize with model name.""" + self.model_name = model_name + try: # authenticate before anything else starts + auth = Auth() + auth.authenticate() + except Exception: + raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") -def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]: - """Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class. 
+ @rank_zero_only + def _upload_model(self, filepath: str) -> None: + # todo: uploading on background so training does not stop + # todo: use filename as version but need to validate that such version does not exist yet + upload_model(name=self.model_name, model=filepath) - Args: - checkpoint_cls: The ModelCheckpoint class to extend - Returns: - A LitModelCheckpoint class extending the given ModelCheckpoint class - """ +# Create specific implementations +if _LIGHTNING_AVAILABLE: - class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore + class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint): """Lightning ModelCheckpoint with LitModel support. Args: - model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' - where entity is either your username or the name of an organization you are part of. + model_name: Name of the model to upload in format 'organization/teamspace/modelname' args: Additional arguments to pass to the parent class. kwargs: Additional keyword arguments to pass to the parent class. """ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: - """Initialize the LitModelCheckpoint.""" - super().__init__(*args, **kwargs) - self.model_name = model_name - - try: # authenticate before anything else starts - auth = Auth() - auth.authenticate() - except Exception: - raise ConnectionError("Unable to authenticate with Lightning Cloud. 
Check your credentials.") - - @rank_zero_only - def _upload_model(self, filepath: str) -> None: - # todo: uploading on background so training does nt stops - # todo: use filename as version but need to validate that such version does not exists yet - upload_model(name=self.model_name, model=filepath) - - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: + """Initialize the checkpoint with model name and other parameters.""" + _LightningModelCheckpoint.__init__(self, *args, **kwargs) + LitModelCheckpointMixin.__init__(self, model_name) + + def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None: super()._save_checkpoint(trainer, filepath) self._upload_model(filepath) - return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) - -# Create explicit classes with specific names -if _LIGHTNING_AVAILABLE: - LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) if _PYTORCHLIGHTNING_AVAILABLE: - PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint) + + class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint): + """PyTorch Lightning ModelCheckpoint with LitModel support. + + Args: + model_name: Name of the model to upload in format 'organization/teamspace/modelname' + args: Additional arguments to pass to the parent class. + kwargs: Additional keyword arguments to pass to the parent class. 
+ """ + + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: + """Initialize the checkpoint with model name and other parameters.""" + _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs) + LitModelCheckpointMixin.__init__(self, model_name) + + def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + self._upload_model(filepath) diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py new file mode 100644 index 0000000..c87963b --- /dev/null +++ b/tests/integrations/__init__.py @@ -0,0 +1,7 @@ +import pytest +from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE + +_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available") +_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif( + not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available" +) diff --git a/tests/integrations/test_checkpoints.py b/tests/integrations/test_checkpoints.py index 42249ce..5a3a100 100644 --- a/tests/integrations/test_checkpoints.py +++ b/tests/integrations/test_checkpoints.py @@ -1,18 +1,17 @@ +import pickle import re from unittest import mock import pytest -from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE + +from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING @pytest.mark.parametrize( "importing", [ - pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")), - pytest.param( - "pytorch_lightning", - marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"), - ), + pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING), + pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING), ], ) @mock.patch("litmodels.io.cloud.sdk_upload_model") @@ -51,3 +50,21 @@ def 
test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, for call_args in mock_upload_model.call_args_list: path = call_args[1]["path"] assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path) + + +@pytest.mark.parametrize( + "importing", + [ + pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING), + pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING), + ], +) +@mock.patch("litmodels.integrations.checkpoints.Auth") +def test_lightning_checkpointing_pickleable(mock_auth, importing): + if importing == "lightning": + from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint + elif importing == "pytorch_lightning": + from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint + + ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name") + pickle.dumps(ckpt)