diff --git a/src/litmodels/integrations/__init__.py b/src/litmodels/integrations/__init__.py index b8a1b71..c0bbf06 100644 --- a/src/litmodels/integrations/__init__.py +++ b/src/litmodels/integrations/__init__.py @@ -1,5 +1,15 @@ """Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others.""" -from litmodels.integrations.checkpoints import LitModelCheckpoint +from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE -__all__ = ["LitModelCheckpoint"] +__all__ = [] + +if _LIGHTNING_AVAILABLE: + from litmodels.integrations.checkpoints import LightningModelCheckpoint + + __all__ += ["LightningModelCheckpoint"] + +if _PYTORCHLIGHTNING_AVAILABLE: + from litmodels.integrations.checkpoints import PTLightningModelCheckpoint + + __all__ += ["PTLightningModelCheckpoint"] diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 016f2d8..7d76852 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Type, TypeVar, cast from lightning_sdk.lightning_cloud.login import Auth @@ -7,39 +7,58 @@ if _LIGHTNING_AVAILABLE: from lightning.pytorch import Trainer - from lightning.pytorch.callbacks import ModelCheckpoint -elif _PYTORCHLIGHTNING_AVAILABLE: + from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint +if _PYTORCHLIGHTNING_AVAILABLE: from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint -else: - raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'") + from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint -class LitModelCheckpoint(ModelCheckpoint): - """Lightning ModelCheckpoint with LitModel support. +# Type variable for the ModelCheckpoint class +ModelCheckpointType = TypeVar("ModelCheckpointType") + + +def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]: + """Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class. Args: - model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' - where entity is either your username or the name of an organization you are part of. - args: Additional arguments to pass to the parent class. - kwargs: Additional keyword arguments to pass to the parent class. + checkpoint_cls: The ModelCheckpoint class to extend + Returns: + A LitModelCheckpoint class extending the given ModelCheckpoint class """ - def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: - """Initialize the LitModelCheckpoint.""" - super().__init__(*args, **kwargs) - self.model_name = model_name - - try: - # authenticate before anything else starts - auth = Auth() - auth.authenticate() - except Exception: - raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") - - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) - # todo: uploading on background so training does nt stops - # todo: use filename as version but need to validate that such version does not exists yet - upload_model(name=self.model_name, model=filepath) + class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore + """Lightning ModelCheckpoint with LitModel support. + + Args: + model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' + where entity is either your username or the name of an organization you are part of. + args: Additional arguments to pass to the parent class. + kwargs: Additional keyword arguments to pass to the parent class. + """ + + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: + """Initialize the LitModelCheckpoint.""" + super().__init__(*args, **kwargs) + self.model_name = model_name + + try: # authenticate before anything else starts + auth = Auth() + auth.authenticate() + except Exception: + raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") + + def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + # todo: uploading on background so training does nt stops + # todo: use filename as version but need to validate that such version does not exists yet + upload_model(name=self.model_name, model=filepath) + + return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) + + +# Create explicit classes with specific names +if _LIGHTNING_AVAILABLE: + LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) +if _PYTORCHLIGHTNING_AVAILABLE: + PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint) diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_checkpoints.py similarity index 51% rename from tests/integrations/test_lightning.py rename to tests/integrations/test_checkpoints.py index 11536c3..42249ce 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_checkpoints.py @@ -1,20 +1,37 @@ import re from unittest import mock -from litmodels.integrations.checkpoints import LitModelCheckpoint +import pytest from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE -if _LIGHTNING_AVAILABLE: - from lightning import Trainer - from lightning.pytorch.demos.boring_classes import BoringModel -elif _PYTORCHLIGHTNING_AVAILABLE: - from pytorch_lightning import Trainer - from pytorch_lightning.demos.boring_classes import BoringModel - +@pytest.mark.parametrize( + "importing", + [ + pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")), + pytest.param( + "pytorch_lightning", + marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"), + ), + ], +) @mock.patch("litmodels.io.cloud.sdk_upload_model") @mock.patch("litmodels.integrations.checkpoints.Auth") -def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path): +def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path): + if importing == "lightning": + from lightning import Trainer + from lightning.pytorch.callbacks import ModelCheckpoint + from lightning.pytorch.demos.boring_classes import BoringModel + from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint + elif importing == "pytorch_lightning": + from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + from pytorch_lightning.demos.boring_classes import BoringModel + + # Validate inheritance + assert issubclass(LitModelCheckpoint, ModelCheckpoint) + mock_upload_model.return_value.name = "org-name/teamspace/model-name" trainer = Trainer(