From b66a37f71275cffd3ffb40bc4f0ab31aa0e703ff Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 10 Mar 2025 21:07:36 +0100 Subject: [PATCH 1/4] ckpt: both inheritance --- src/litmodels/integrations/__init__.py | 14 +++- src/litmodels/integrations/checkpoints.py | 77 ++++++++++++------- ...{test_lightning.py => test_checkpoints.py} | 35 ++++++--- 3 files changed, 86 insertions(+), 40 deletions(-) rename tests/integrations/{test_lightning.py => test_checkpoints.py} (51%) diff --git a/src/litmodels/integrations/__init__.py b/src/litmodels/integrations/__init__.py index b8a1b71..c0bbf06 100644 --- a/src/litmodels/integrations/__init__.py +++ b/src/litmodels/integrations/__init__.py @@ -1,5 +1,15 @@ """Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others.""" -from litmodels.integrations.checkpoints import LitModelCheckpoint +from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE -__all__ = ["LitModelCheckpoint"] +__all__ = [] + +if _LIGHTNING_AVAILABLE: + from litmodels.integrations.checkpoints import LightningModelCheckpoint + + __all__ += ["LightningModelCheckpoint"] + +if _PYTORCHLIGHTNING_AVAILABLE: + from litmodels.integrations.checkpoints import PTLightningModelCheckpoint + + __all__ += ["PTLightningModelCheckpoint"] diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 016f2d8..1e6d84b 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Type, TypeVar, cast from lightning_sdk.lightning_cloud.login import Auth @@ -7,39 +7,58 @@ if _LIGHTNING_AVAILABLE: from lightning.pytorch import Trainer - from lightning.pytorch.callbacks import ModelCheckpoint -elif _PYTORCHLIGHTNING_AVAILABLE: + from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint +if _PYTORCHLIGHTNING_AVAILABLE: from pytorch_lightning import 
Trainer - from pytorch_lightning.callbacks import ModelCheckpoint -else: - raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'") + from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint -class LitModelCheckpoint(ModelCheckpoint): - """Lightning ModelCheckpoint with LitModel support. +# Type variable for the ModelCheckpoint class +ModelCheckpointType = TypeVar("ModelCheckpointType") + + +def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]: + """Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class. Args: - model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' - where entity is either your username or the name of an organization you are part of. - args: Additional arguments to pass to the parent class. - kwargs: Additional keyword arguments to pass to the parent class. + checkpoint_cls: The ModelCheckpoint class to extend + Returns: + A LitModelCheckpoint class extending the given ModelCheckpoint class """ - def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: - """Initialize the LitModelCheckpoint.""" - super().__init__(*args, **kwargs) - self.model_name = model_name - - try: - # authenticate before anything else starts - auth = Auth() - auth.authenticate() - except Exception: - raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") - - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) - # todo: uploading on background so training does nt stops - # todo: use filename as version but need to validate that such version does not exists yet - upload_model(name=self.model_name, model=filepath) + class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore + """Lightning ModelCheckpoint with LitModel support. 
+ + Args: + model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' + where entity is either your username or the name of an organization you are part of. + args: Additional arguments to pass to the parent class. + kwargs: Additional keyword arguments to pass to the parent class. + """ + + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: + """Initialize the LitModelCheckpoint.""" + super().__init__(*args, **kwargs) + self.model_name = model_name + + try: + # authenticate before anything else starts + auth = Auth() + auth.authenticate() + except Exception: + raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") + + def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + # upload model after checkpoint is saved + upload_model(name=self.model_name, model=filepath) + + return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) + + +# Create explicit classes with specific names if needed +if _LIGHTNING_AVAILABLE: + LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) +if _PYTORCHLIGHTNING_AVAILABLE: + PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint) diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_checkpoints.py similarity index 51% rename from tests/integrations/test_lightning.py rename to tests/integrations/test_checkpoints.py index 11536c3..42249ce 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_checkpoints.py @@ -1,20 +1,37 @@ import re from unittest import mock -from litmodels.integrations.checkpoints import LitModelCheckpoint +import pytest from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE -if _LIGHTNING_AVAILABLE: - from lightning import Trainer - from lightning.pytorch.demos.boring_classes import BoringModel -elif 
_PYTORCHLIGHTNING_AVAILABLE: - from pytorch_lightning import Trainer - from pytorch_lightning.demos.boring_classes import BoringModel - +@pytest.mark.parametrize( + "importing", + [ + pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")), + pytest.param( + "pytorch_lightning", + marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"), + ), + ], +) @mock.patch("litmodels.io.cloud.sdk_upload_model") @mock.patch("litmodels.integrations.checkpoints.Auth") -def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path): +def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path): + if importing == "lightning": + from lightning import Trainer + from lightning.pytorch.callbacks import ModelCheckpoint + from lightning.pytorch.demos.boring_classes import BoringModel + from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint + elif importing == "pytorch_lightning": + from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + from pytorch_lightning.demos.boring_classes import BoringModel + + # Validate inheritance + assert issubclass(LitModelCheckpoint, ModelCheckpoint) + mock_upload_model.return_value.name = "org-name/teamspace/model-name" trainer = Trainer( From 90e3388c78426f5d4f3eab3c45a1d03290fb4a8a Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 10 Mar 2025 21:12:11 +0100 Subject: [PATCH 2/4] Apply suggestions from code review --- src/litmodels/integrations/checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 1e6d84b..47851d8 100644 --- a/src/litmodels/integrations/checkpoints.py +++ 
b/src/litmodels/integrations/checkpoints.py @@ -57,7 +57,7 @@ def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) -# Create explicit classes with specific names if needed +# Create explicit classes with specific names if _LIGHTNING_AVAILABLE: LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) if _PYTORCHLIGHTNING_AVAILABLE: From 7d060f4ee64768e6f23f1abab96e0a2a9791341c Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 10 Mar 2025 21:16:09 +0100 Subject: [PATCH 3/4] Apply suggestions from code review --- src/litmodels/integrations/checkpoints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 47851d8..80dbd24 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -51,7 +51,8 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: super()._save_checkpoint(trainer, filepath) - # upload model after checkpoint is saved + # todo: uploading on background so training does not stop + # todo: use filename as version but need to validate that such version does not exist yet upload_model(name=self.model_name, model=filepath) return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) From 3223b137ad706704eb40e042ffc579e4b2819bb8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 10 Mar 2025 21:16:34 +0100 Subject: [PATCH 4/4] Apply suggestions from code review --- src/litmodels/integrations/checkpoints.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 80dbd24..7d76852 100644 --- 
a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -42,8 +42,7 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.model_name = model_name - try: - # authenticate before anything else starts + try: # authenticate before anything else starts auth = Auth() auth.authenticate() except Exception: