Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/litmodels/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
"""Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others."""

from litmodels.integrations.checkpoints import LitModelCheckpoint
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE

__all__ = ["LitModelCheckpoint"]
__all__ = []

if _LIGHTNING_AVAILABLE:
from litmodels.integrations.checkpoints import LightningModelCheckpoint

__all__ += ["LightningModelCheckpoint"]

if _PYTORCHLIGHTNING_AVAILABLE:
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint

__all__ += ["PTLightningModelCheckpoint"]
77 changes: 48 additions & 29 deletions src/litmodels/integrations/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Type, TypeVar, cast

from lightning_sdk.lightning_cloud.login import Auth

Expand All @@ -7,39 +7,58 @@

if _LIGHTNING_AVAILABLE:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
elif _PYTORCHLIGHTNING_AVAILABLE:
from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint
if _PYTORCHLIGHTNING_AVAILABLE:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
else:
raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'")
from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint


class LitModelCheckpoint(ModelCheckpoint):
"""Lightning ModelCheckpoint with LitModel support.
# Type variable for the ModelCheckpoint class
ModelCheckpointType = TypeVar("ModelCheckpointType")


def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]:
"""Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class.

Args:
model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
args: Additional arguments to pass to the parent class.
kwargs: Additional keyword arguments to pass to the parent class.
checkpoint_cls: The ModelCheckpoint class to extend

Returns:
A LitModelCheckpoint class extending the given ModelCheckpoint class
"""

def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
"""Initialize the LitModelCheckpoint."""
super().__init__(*args, **kwargs)
self.model_name = model_name

try:
# authenticate before anything else starts
auth = Auth()
auth.authenticate()
except Exception:
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")

def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
super()._save_checkpoint(trainer, filepath)
# todo: uploading on background so training does nt stops
# todo: use filename as version but need to validate that such version does not exists yet
upload_model(name=self.model_name, model=filepath)
class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore
"""Lightning ModelCheckpoint with LitModel support.

Args:
model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
args: Additional arguments to pass to the parent class.
kwargs: Additional keyword arguments to pass to the parent class.
"""

def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
"""Initialize the LitModelCheckpoint."""
super().__init__(*args, **kwargs)
self.model_name = model_name

try: # authenticate before anything else starts
auth = Auth()
auth.authenticate()
except Exception:
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")

def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
super()._save_checkpoint(trainer, filepath)
# todo: uploading on background so training does nt stops
# todo: use filename as version but need to validate that such version does not exists yet
upload_model(name=self.model_name, model=filepath)

return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate)


# Create explicit classes with specific names
if _LIGHTNING_AVAILABLE:
LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint)
if _PYTORCHLIGHTNING_AVAILABLE:
PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint)
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
import re
from unittest import mock

from litmodels.integrations.checkpoints import LitModelCheckpoint
import pytest
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE

if _LIGHTNING_AVAILABLE:
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
elif _PYTORCHLIGHTNING_AVAILABLE:
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


@pytest.mark.parametrize(
"importing",
[
pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")),
pytest.param(
"pytorch_lightning",
marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"),
),
],
)
@mock.patch("litmodels.io.cloud.sdk_upload_model")
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path):
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path):
if importing == "lightning":
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
elif importing == "pytorch_lightning":
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel

# Validate inheritance
assert issubclass(LitModelCheckpoint, ModelCheckpoint)

mock_upload_model.return_value.name = "org-name/teamspace/model-name"

trainer = Trainer(
Expand Down
Loading