Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 55 additions & 40 deletions src/litmodels/integrations/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any

from lightning_sdk.lightning_cloud.login import Auth
from lightning_utilities.core.rank_zero import rank_zero_only
Expand All @@ -7,63 +7,78 @@
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE

if _LIGHTNING_AVAILABLE:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint

if TYPE_CHECKING:
from lightning.pytorch import Trainer


if _PYTORCHLIGHTNING_AVAILABLE:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint

if TYPE_CHECKING:
from pytorch_lightning import Trainer


# Base class to be inherited by the framework-specific checkpoint callbacks below
class LitModelCheckpointMixin:
    """Mixin adding LitModel cloud-upload support to a ``ModelCheckpoint`` callback.

    Authenticates with Lightning Cloud at construction time so credential
    problems surface before training starts, and uploads each saved
    checkpoint file under ``model_name``.
    """

    def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
        """Initialize with model name and authenticate with Lightning Cloud.

        Args:
            model_name: Name of the model to upload in format 'organization/teamspace/modelname'
            args: Accepted for cooperative multiple inheritance; not used here.
            kwargs: Accepted for cooperative multiple inheritance; not used here.

        Raises:
            ConnectionError: If authentication with Lightning Cloud fails.
        """
        self.model_name = model_name

        try:  # authenticate before anything else starts
            auth = Auth()
            auth.authenticate()
        except Exception as ex:
            # chain the original exception so the real root cause is not hidden
            raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") from ex

    @rank_zero_only
    def _upload_model(self, filepath: str) -> None:
        """Upload the checkpoint at ``filepath``; runs on rank zero only."""
        # todo: upload in the background so training does not stop
        # todo: use filename as version, but need to validate that such version does not exist yet
        upload_model(name=self.model_name, model=filepath)
# Create specific implementations
# Concrete implementation for the unified `lightning` package
if _LIGHTNING_AVAILABLE:

    class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint):
        """Lightning ModelCheckpoint with LitModel support.

        Args:
            model_name: Name of the model to upload in format 'organization/teamspace/modelname'
            args: Additional arguments to pass to the parent class.
            kwargs: Additional keyword arguments to pass to the parent class.
        """

        def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
            """Initialize the checkpoint with model name and other parameters."""
            # Call each base explicitly: checkpoint arguments go to the Lightning
            # parent, the model name goes to the mixin (which also authenticates).
            _LightningModelCheckpoint.__init__(self, *args, **kwargs)
            LitModelCheckpointMixin.__init__(self, model_name)

        def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
            """Save the checkpoint locally, then push it to Lightning Cloud."""
            super()._save_checkpoint(trainer, filepath)
            self._upload_model(filepath)

# Concrete implementation for the standalone `pytorch_lightning` package
if _PYTORCHLIGHTNING_AVAILABLE:

    class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
        """PyTorch Lightning ModelCheckpoint with LitModel support.

        Args:
            model_name: Name of the model to upload in format 'organization/teamspace/modelname'
            args: Additional arguments to pass to the parent class.
            kwargs: Additional keyword arguments to pass to the parent class.
        """

        def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
            """Initialize the checkpoint with model name and other parameters."""
            # Call each base explicitly: checkpoint arguments go to the
            # pytorch_lightning parent, the model name goes to the mixin
            # (which also authenticates).
            _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
            LitModelCheckpointMixin.__init__(self, model_name)

        def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
            """Save the checkpoint locally, then push it to Lightning Cloud."""
            super()._save_checkpoint(trainer, filepath)
            self._upload_model(filepath)
7 changes: 7 additions & 0 deletions tests/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Shared skip markers for the integrations test package."""

import pytest
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE

# Skip a test when the unified `lightning` package is not installed.
_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")
# Skip a test when the standalone `pytorch_lightning` package is not installed.
_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif(
    not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"
)
29 changes: 23 additions & 6 deletions tests/integrations/test_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import pickle
import re
from unittest import mock

import pytest
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE

from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING


@pytest.mark.parametrize(
"importing",
[
pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")),
pytest.param(
"pytorch_lightning",
marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"),
),
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
],
)
@mock.patch("litmodels.io.cloud.sdk_upload_model")
Expand Down Expand Up @@ -51,3 +50,21 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing,
for call_args in mock_upload_model.call_args_list:
path = call_args[1]["path"]
assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path)


@pytest.mark.parametrize(
    "importing",
    [
        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
    ],
)
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpointing_pickleable(mock_auth, importing):
    """Checkpoint callbacks must survive a pickle round-trip (e.g. for DDP spawn workers)."""
    if importing == "lightning":
        from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
    elif importing == "pytorch_lightning":
        from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint

    ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
    # Round-trip instead of only `dumps`: verifies the object can be deserialized
    # and that its state is restored, not merely that it is serializable.
    restored = pickle.loads(pickle.dumps(ckpt))
    assert restored.model_name == ckpt.model_name
Loading