Skip to content

Commit bd7681c

Browse files
authored
ckpt: pickleable (#54)
1 parent 41c4cad commit bd7681c

File tree

3 files changed

+85
-46
lines changed

3 files changed

+85
-46
lines changed
Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Type, TypeVar, cast
1+
from typing import TYPE_CHECKING, Any
22

33
from lightning_sdk.lightning_cloud.login import Auth
44
from lightning_utilities.core.rank_zero import rank_zero_only
@@ -7,63 +7,78 @@
77
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
88

99
if _LIGHTNING_AVAILABLE:
10-
from lightning.pytorch import Trainer
11-
from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint
10+
from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
11+
12+
if TYPE_CHECKING:
13+
from lightning.pytorch import Trainer
14+
15+
1216
if _PYTORCHLIGHTNING_AVAILABLE:
13-
from pytorch_lightning import Trainer
14-
from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint
17+
from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
18+
19+
if TYPE_CHECKING:
20+
from pytorch_lightning import Trainer
21+
1522

23+
# Base class to be inherited
24+
class LitModelCheckpointMixin:
25+
"""Mixin class for LitModel checkpoint functionality."""
1626

17-
# Type variable for the ModelCheckpoint class
18-
ModelCheckpointType = TypeVar("ModelCheckpointType")
27+
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
28+
"""Initialize with model name."""
29+
self.model_name = model_name
1930

31+
try: # authenticate before anything else starts
32+
auth = Auth()
33+
auth.authenticate()
34+
except Exception:
35+
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
2036

21-
def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]:
22-
"""Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class.
37+
@rank_zero_only
38+
def _upload_model(self, filepath: str) -> None:
39+
# todo: uploading on background so training does nt stops
40+
# todo: use filename as version but need to validate that such version does not exists yet
41+
upload_model(name=self.model_name, model=filepath)
2342

24-
Args:
25-
checkpoint_cls: The ModelCheckpoint class to extend
2643

27-
Returns:
28-
A LitModelCheckpoint class extending the given ModelCheckpoint class
29-
"""
44+
# Create specific implementations
45+
if _LIGHTNING_AVAILABLE:
3046

31-
class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore
47+
class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint):
3248
"""Lightning ModelCheckpoint with LitModel support.
3349
3450
Args:
35-
model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
36-
where entity is either your username or the name of an organization you are part of.
51+
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
3752
args: Additional arguments to pass to the parent class.
3853
kwargs: Additional keyword arguments to pass to the parent class.
3954
"""
4055

4156
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
42-
"""Initialize the LitModelCheckpoint."""
43-
super().__init__(*args, **kwargs)
44-
self.model_name = model_name
45-
46-
try: # authenticate before anything else starts
47-
auth = Auth()
48-
auth.authenticate()
49-
except Exception:
50-
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
51-
52-
@rank_zero_only
53-
def _upload_model(self, filepath: str) -> None:
54-
# todo: uploading on background so training does nt stops
55-
# todo: use filename as version but need to validate that such version does not exists yet
56-
upload_model(name=self.model_name, model=filepath)
57-
58-
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
57+
"""Initialize the checkpoint with model name and other parameters."""
58+
_LightningModelCheckpoint.__init__(self, *args, **kwargs)
59+
LitModelCheckpointMixin.__init__(self, model_name)
60+
61+
def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
5962
super()._save_checkpoint(trainer, filepath)
6063
self._upload_model(filepath)
6164

62-
return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate)
63-
6465

65-
# Create explicit classes with specific names
66-
if _LIGHTNING_AVAILABLE:
67-
LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint)
6866
if _PYTORCHLIGHTNING_AVAILABLE:
69-
PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint)
67+
68+
class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
69+
"""PyTorch Lightning ModelCheckpoint with LitModel support.
70+
71+
Args:
72+
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
73+
args: Additional arguments to pass to the parent class.
74+
kwargs: Additional keyword arguments to pass to the parent class.
75+
"""
76+
77+
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
78+
"""Initialize the checkpoint with model name and other parameters."""
79+
_PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
80+
LitModelCheckpointMixin.__init__(self, model_name)
81+
82+
def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
83+
super()._save_checkpoint(trainer, filepath)
84+
self._upload_model(filepath)

tests/integrations/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
3+
4+
_SKIP_IF_LIGHTNING_MISSING = pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")
5+
_SKIP_IF_PYTORCHLIGHTNING_MISSING = pytest.mark.skipif(
6+
not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"
7+
)

tests/integrations/test_checkpoints.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1+
import pickle
12
import re
23
from unittest import mock
34

45
import pytest
5-
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
6+
7+
from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING
68

79

810
@pytest.mark.parametrize(
911
"importing",
1012
[
11-
pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")),
12-
pytest.param(
13-
"pytorch_lightning",
14-
marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"),
15-
),
13+
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
14+
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
1615
],
1716
)
1817
@mock.patch("litmodels.io.cloud.sdk_upload_model")
@@ -51,3 +50,21 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing,
5150
for call_args in mock_upload_model.call_args_list:
5251
path = call_args[1]["path"]
5352
assert re.match(r".*[/\\]lightning_logs[/\\]version_\d+[/\\]checkpoints[/\\]epoch=\d+-step=\d+\.ckpt$", path)
53+
54+
55+
@pytest.mark.parametrize(
56+
"importing",
57+
[
58+
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_MISSING),
59+
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
60+
],
61+
)
62+
@mock.patch("litmodels.integrations.checkpoints.Auth")
63+
def test_lightning_checkpointing_pickleable(mock_auth, importing):
64+
if importing == "lightning":
65+
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
66+
elif importing == "pytorch_lightning":
67+
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
68+
69+
ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
70+
pickle.dumps(ckpt)

0 commit comments

Comments
 (0)