Skip to content

Commit b66a37f

Browse files
committed
ckpt: both inheritance
1 parent a96a450 commit b66a37f

File tree

3 files changed

+86
-40
lines changed

3 files changed

+86
-40
lines changed
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
"""Integrations with training frameworks like PyTorch Lightning, TensorFlow, and others."""
22

3-
from litmodels.integrations.checkpoints import LitModelCheckpoint
3+
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
44

5-
__all__ = ["LitModelCheckpoint"]
5+
__all__ = []
6+
7+
if _LIGHTNING_AVAILABLE:
8+
from litmodels.integrations.checkpoints import LightningModelCheckpoint
9+
10+
__all__ += ["LightningModelCheckpoint"]
11+
12+
if _PYTORCHLIGHTNING_AVAILABLE:
13+
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint
14+
15+
__all__ += ["PTLightningModelCheckpoint"]
Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Type, TypeVar, cast
22

33
from lightning_sdk.lightning_cloud.login import Auth
44

@@ -7,39 +7,58 @@
77

88
if _LIGHTNING_AVAILABLE:
99
from lightning.pytorch import Trainer
10-
from lightning.pytorch.callbacks import ModelCheckpoint
11-
elif _PYTORCHLIGHTNING_AVAILABLE:
10+
from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint
11+
if _PYTORCHLIGHTNING_AVAILABLE:
1212
from pytorch_lightning import Trainer
13-
from pytorch_lightning.callbacks import ModelCheckpoint
14-
else:
15-
raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'")
13+
from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint
1614

1715

18-
class LitModelCheckpoint(ModelCheckpoint):
19-
"""Lightning ModelCheckpoint with LitModel support.
16+
# Type variable for the ModelCheckpoint class
17+
ModelCheckpointType = TypeVar("ModelCheckpointType")
18+
19+
20+
def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]:
21+
"""Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class.
2022
2123
Args:
22-
model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
23-
where entity is either your username or the name of an organization you are part of.
24-
args: Additional arguments to pass to the parent class.
25-
kwargs: Additional keyword arguments to pass to the parent class.
24+
checkpoint_cls: The ModelCheckpoint class to extend
2625
26+
Returns:
27+
A LitModelCheckpoint class extending the given ModelCheckpoint class
2728
"""
2829

29-
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
30-
"""Initialize the LitModelCheckpoint."""
31-
super().__init__(*args, **kwargs)
32-
self.model_name = model_name
33-
34-
try:
35-
# authenticate before anything else starts
36-
auth = Auth()
37-
auth.authenticate()
38-
except Exception:
39-
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
40-
41-
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
42-
super()._save_checkpoint(trainer, filepath)
43-
# todo: uploading on background so training does nt stops
44-
# todo: use filename as version but need to validate that such version does not exists yet
45-
upload_model(name=self.model_name, model=filepath)
30+
class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore
31+
"""Lightning ModelCheckpoint with LitModel support.
32+
33+
Args:
34+
model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
35+
where entity is either your username or the name of an organization you are part of.
36+
args: Additional arguments to pass to the parent class.
37+
kwargs: Additional keyword arguments to pass to the parent class.
38+
"""
39+
40+
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
41+
"""Initialize the LitModelCheckpoint."""
42+
super().__init__(*args, **kwargs)
43+
self.model_name = model_name
44+
45+
try:
46+
# authenticate before anything else starts
47+
auth = Auth()
48+
auth.authenticate()
49+
except Exception:
50+
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
51+
52+
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
53+
super()._save_checkpoint(trainer, filepath)
54+
# upload model after checkpoint is saved
55+
upload_model(name=self.model_name, model=filepath)
56+
57+
return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate)
58+
59+
60+
# Create explicit classes with specific names if needed
61+
if _LIGHTNING_AVAILABLE:
62+
LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint)
63+
if _PYTORCHLIGHTNING_AVAILABLE:
64+
PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint)
Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,37 @@
11
import re
22
from unittest import mock
33

4-
from litmodels.integrations.checkpoints import LitModelCheckpoint
4+
import pytest
55
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
66

7-
if _LIGHTNING_AVAILABLE:
8-
from lightning import Trainer
9-
from lightning.pytorch.demos.boring_classes import BoringModel
10-
elif _PYTORCHLIGHTNING_AVAILABLE:
11-
from pytorch_lightning import Trainer
12-
from pytorch_lightning.demos.boring_classes import BoringModel
13-
147

8+
@pytest.mark.parametrize(
9+
"importing",
10+
[
11+
pytest.param("lightning", marks=pytest.mark.skipif(not _LIGHTNING_AVAILABLE, reason="Lightning not available")),
12+
pytest.param(
13+
"pytorch_lightning",
14+
marks=pytest.mark.skipif(not _PYTORCHLIGHTNING_AVAILABLE, reason="PyTorch Lightning not available"),
15+
),
16+
],
17+
)
1518
@mock.patch("litmodels.io.cloud.sdk_upload_model")
1619
@mock.patch("litmodels.integrations.checkpoints.Auth")
17-
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, tmp_path):
20+
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path):
21+
if importing == "lightning":
22+
from lightning import Trainer
23+
from lightning.pytorch.callbacks import ModelCheckpoint
24+
from lightning.pytorch.demos.boring_classes import BoringModel
25+
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
26+
elif importing == "pytorch_lightning":
27+
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
28+
from pytorch_lightning import Trainer
29+
from pytorch_lightning.callbacks import ModelCheckpoint
30+
from pytorch_lightning.demos.boring_classes import BoringModel
31+
32+
# Validate inheritance
33+
assert issubclass(LitModelCheckpoint, ModelCheckpoint)
34+
1835
mock_upload_model.return_value.name = "org-name/teamspace/model-name"
1936

2037
trainer = Trainer(

0 commit comments

Comments
 (0)