
Commit dd29912

enable empty checkpoint name (#72)
* enable default checkpoint name
* tests + mocks
* linting

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8b9492a commit dd29912

File tree: 4 files changed (+77, -28 lines)


src/litmodels/integrations/checkpoints.py

Lines changed: 49 additions & 17 deletions
@@ -1,31 +1,42 @@
-from typing import TYPE_CHECKING, Any
+from abc import ABC
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Optional
 
 from lightning_sdk.lightning_cloud.login import Auth
-from lightning_utilities.core.rank_zero import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only, rank_zero_warn
 
 from litmodels import upload_model
 from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
 
 if _LIGHTNING_AVAILABLE:
     from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
 
-    if TYPE_CHECKING:
-        from lightning.pytorch import Trainer
-
 
 if _PYTORCHLIGHTNING_AVAILABLE:
     from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
 
-    if TYPE_CHECKING:
-        from pytorch_lightning import Trainer
+
+if TYPE_CHECKING:
+    if _LIGHTNING_AVAILABLE:
+        import lightning.pytorch as pl
+    if _PYTORCHLIGHTNING_AVAILABLE:
+        import pytorch_lightning as pl
 
 
 # Base class to be inherited
-class LitModelCheckpointMixin:
+class LitModelCheckpointMixin(ABC):
     """Mixin class for LitModel checkpoint functionality."""
 
-    def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
+    # mainly for mocking reasons
+    _datetime_stamp: str = datetime.now().strftime("%Y%m%d-%H%M")
+    model_name: Optional[str] = None
+
+    def __init__(self, model_name: Optional[str]) -> None:
         """Initialize with model name."""
+        if not model_name:
+            rank_zero_warn(
+                "The model name is not defined, so we will continue with the LightningModule class name and the current timestamp"
+            )
         self.model_name = model_name
 
         try:  # authenticate before anything else starts
@@ -38,8 +49,19 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
     def _upload_model(self, filepath: str) -> None:
         # todo: upload in the background so training does not stop
         # todo: use filename as version but need to validate that such version does not exist yet
+        if not self.model_name:
+            raise RuntimeError(
+                "Model name is not specified nor updated by the `setup` method via the Trainer."
+                " Please set the model name before uploading or ensure that the `setup` method is called."
+            )
         upload_model(name=self.model_name, model=filepath)
 
+    def _update_model_name(self, pl_model: "pl.LightningModule") -> None:
+        if self.model_name:
+            return
+        # set the model name from the Lightning module class name plus a timestamp
+        self.model_name = pl_model.__class__.__name__ + f"_{self._datetime_stamp}"
+
 
 # Create specific implementations
 if _LIGHTNING_AVAILABLE:
@@ -53,15 +75,20 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint):
         kwargs: Additional keyword arguments to pass to the parent class.
         """
 
-    def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
         """Initialize the checkpoint with model name and other parameters."""
         _LightningModelCheckpoint.__init__(self, *args, **kwargs)
         LitModelCheckpointMixin.__init__(self, model_name)
 
-    def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
+        """Set up the checkpoint callback."""
+        super().setup(trainer, pl_module, stage)
+        self._update_model_name(pl_module)
+
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        """Extend the save checkpoint method to upload the model."""
         super()._save_checkpoint(trainer, filepath)
-        if trainer.is_global_zero:
-            # Only upload from the main process
+        if trainer.is_global_zero:  # Only upload from the main process
             self._upload_model(filepath)
 
 
@@ -76,13 +103,18 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
         kwargs: Additional keyword arguments to pass to the parent class.
         """
 
-    def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
         """Initialize the checkpoint with model name and other parameters."""
         _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
         LitModelCheckpointMixin.__init__(self, model_name)
 
-    def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
+        """Set up the checkpoint callback."""
+        super().setup(trainer, pl_module, stage)
+        self._update_model_name(pl_module)
+
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        """Extend the save checkpoint method to upload the model."""
        super()._save_checkpoint(trainer, filepath)
-        if trainer.is_global_zero:
-            # Only upload from the main process
+        if trainer.is_global_zero:  # Only upload from the main process
             self._upload_model(filepath)
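
In short, `model_name` is now optional: if it is omitted, the callback warns and, during `setup`, derives the registry name from the LightningModule class name plus a timestamp. Below is a minimal usage sketch of that behavior, not taken from the repo's docs; it assumes `lightning` and `litmodels` are installed and the user is authenticated, and the `BoringModel` demo module and `dirpath` value are only illustrative.

# Hypothetical usage sketch of the default checkpoint name introduced in this commit.
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

from litmodels.integrations.checkpoints import LightningModelCheckpoint

# No model_name passed: the callback warns and, in `setup`, falls back to
# "<LightningModule class name>_<YYYYMMDD-HHMM>", e.g. "BoringModel_20250102-1213".
ckpt = LightningModelCheckpoint(dirpath="checkpoints/")  # dirpath is illustrative

trainer = Trainer(max_epochs=1, callbacks=ckpt)
trainer.fit(BoringModel())  # each saved checkpoint is also uploaded under the derived name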

tests/integrations/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -16,3 +16,6 @@
 _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 = pytest.mark.skipif(
     not _PYTORCHLIGHTNING_GREATER_EQUAL_2_5_1, reason="PyTorch Lightning without integration introduced in 2.5.1"
 )
+
+LIT_ORG = "lightning-ai"
+LIT_TEAMSPACE = "LitModels"

tests/integrations/test_checkpoints.py

Lines changed: 19 additions & 7 deletions
@@ -14,9 +14,17 @@
         pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_MISSING),
     ],
 )
+@pytest.mark.parametrize("with_model_name", [True, False])
+@mock.patch("litmodels.integrations.checkpoints.LitModelCheckpointMixin._datetime_stamp", return_value="20250102-1213")
+@mock.patch(
+    "lightning_sdk.models._resolve_teamspace",
+    return_value=mock.MagicMock(owner=mock.MagicMock(name="my-org"), name="dream-team"),
+)
 @mock.patch("litmodels.io.cloud.sdk_upload_model")
 @mock.patch("litmodels.integrations.checkpoints.Auth")
-def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path):
+def test_lightning_checkpoint_callback(
+    mock_auth, mock_upload_model, mock_resolve_teamspace, mock_datetime_stamp, importing, with_model_name, tmp_path
+):
     if importing == "lightning":
         from lightning import Trainer
         from lightning.pytorch.callbacks import ModelCheckpoint
@@ -31,20 +39,24 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing, tmp_path):
     # Validate inheritance
     assert issubclass(LitModelCheckpoint, ModelCheckpoint)
 
-    mock_upload_model.return_value.name = "org-name/teamspace/model-name"
+    ckpt_args = {"model_name": "org-name/teamspace/model-name"} if with_model_name else {}
+    expected_model_registry = ckpt_args.get("model_name", f"BoringModel_{LitModelCheckpoint._datetime_stamp}")
+    mock_upload_model.return_value.name = expected_model_registry
 
     trainer = Trainer(
         max_epochs=2,
-        callbacks=LitModelCheckpoint(model_name="org-name/teamspace/model-name"),
+        callbacks=LitModelCheckpoint(**ckpt_args),
     )
     trainer.fit(BoringModel())
 
-    # expected_path = model_path % str(tmpdir) if "%" in model_path else model_path
-    assert mock_upload_model.call_count == 2
+    assert mock_auth.call_count == 1
     assert mock_upload_model.call_args_list == [
-        mock.call(name="org-name/teamspace/model-name", path=mock.ANY, progress_bar=True, cloud_account=None),
-        mock.call(name="org-name/teamspace/model-name", path=mock.ANY, progress_bar=True, cloud_account=None),
+        mock.call(name=expected_model_registry, path=mock.ANY, progress_bar=True, cloud_account=None),
+        mock.call(name=expected_model_registry, path=mock.ANY, progress_bar=True, cloud_account=None),
     ]
+    called_name_related_mocks = 0 if with_model_name else 1
+    mock_datetime_stamp.call_count == called_name_related_mocks
+    mock_resolve_teamspace.call_count == called_name_related_mocks
 
     # Verify paths match the expected pattern
     for call_args in mock_upload_model.call_args_list:

tests/integrations/test_cloud.py

Lines changed: 6 additions & 4 deletions
@@ -12,10 +12,12 @@
 from litmodels import download_model, upload_model
 from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
 
-from tests.integrations import _SKIP_IF_LIGHTNING_BELLOW_2_5_1, _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1
-
-LIT_ORG = "lightning-ai"
-LIT_TEAMSPACE = "LitModels"
+from tests.integrations import (
+    _SKIP_IF_LIGHTNING_BELLOW_2_5_1,
+    _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1,
+    LIT_ORG,
+    LIT_TEAMSPACE,
+)
 
 
 def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]:
