Skip to content

Commit 38eab58

Browse files
feat: customised model cleaning / clear all local (#94)
* feat: customized model cleaning. * Accept either `model_registry` or the deprecated `model_name`. * Apply suggestions from code review. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5320a99 commit 38eab58

File tree

4 files changed

+82
-28
lines changed

4 files changed

+82
-28
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ trainer = Trainer(
226226
callbacks=[
227227
LightningModelCheckpoint(
228228
# Define the model name - this should be unique to your model
229-
model_name="<organization>/<teamspace>/<model-name>",
229+
model_registry="<organization>/<teamspace>/<model-name>",
230230
)
231231
],
232232
)

examples/train-model-with-lightning-callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
if __name__ == "__main__":
1616
trainer = Trainer(
1717
max_epochs=2,
18-
callbacks=LightningModelCheckpoint(model_name=MY_MODEL_NAME),
18+
callbacks=LightningModelCheckpoint(model_registry=MY_MODEL_NAME),
1919
)
2020
trainer.fit(BoringModel())

src/litmodels/integrations/checkpoints.py

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from abc import ABC
66
from datetime import datetime
77
from functools import lru_cache
8-
from typing import TYPE_CHECKING, Any, Optional
8+
from pathlib import Path
9+
from typing import TYPE_CHECKING, Any, Optional, Union
910

1011
from lightning_sdk.lightning_cloud.login import Auth
1112
from lightning_sdk.utils.resolve import _resolve_teamspace
@@ -105,13 +106,13 @@ def _worker_loop(self) -> None:
105106
rank_zero_warn(f"Unknown task: {task}")
106107
self.task_queue.task_done()
107108

108-
def queue_upload(self, registry_name: str, filepath: str, metadata: Optional[dict] = None) -> None:
109+
def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
109110
"""Queue an upload task."""
110111
self.upload_count += 1
111112
self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
112113
rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")
113114

114-
def queue_remove(self, trainer: "pl.Trainer", filepath: str) -> None:
115+
def queue_remove(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
115116
"""Queue a removal task."""
116117
self.remove_count += 1
117118
self.task_queue.put((Action.REMOVE, (trainer, filepath)))
@@ -132,15 +133,21 @@ class LitModelCheckpointMixin(ABC):
132133
model_registry: Optional[str] = None
133134
_model_manager: ModelManager
134135

135-
def __init__(self, model_name: Optional[str]) -> None:
136-
"""Initialize with model name."""
137-
if not model_name:
136+
def __init__(self, model_registry: Optional[str], clear_all_local: bool = False) -> None:
137+
"""Initialize with model name.
138+
139+
Args:
140+
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
141+
clear_all_local: Whether to clear local models after uploading to the cloud.
142+
"""
143+
if not model_registry:
138144
rank_zero_warn(
139145
"The model is not defined so we will continue with LightningModule names and timestamp of now"
140146
)
141147
self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M")
142148
# remove any / from beginning and end of the name
143-
self.model_registry = model_name.strip("/") if model_name else None
149+
self.model_registry = model_registry.strip("/") if model_registry else None
150+
self._clear_all_local = clear_all_local
144151

145152
try: # authenticate before anything else starts
146153
Auth().authenticate()
@@ -150,7 +157,7 @@ def __init__(self, model_name: Optional[str]) -> None:
150157
self._model_manager = ModelManager()
151158

152159
@rank_zero_only
153-
def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
160+
def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
154161
if not self.model_registry:
155162
raise RuntimeError(
156163
"Model name is not specified neither updated by `setup` method via Trainer."
@@ -170,11 +177,16 @@ def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
170177
metadata.update({"litModels_integration": ckpt_class.__name__})
171178
# Add to queue instead of uploading directly
172179
get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
180+
if self._clear_all_local:
181+
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
173182

174183
@rank_zero_only
175-
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:
184+
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
176185
"""Remove the local version of the model if requested."""
177-
get_model_manager().queue_remove(trainer, filepath)
186+
if self._clear_all_local:
187+
# skip the local removal here; it was already queued right after the upload
188+
return
189+
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
178190

179191
def default_model_name(self, pl_model: "pl.LightningModule") -> str:
180192
"""Generate a default model name based on the class name and timestamp."""
@@ -221,15 +233,30 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
221233
"""Lightning ModelCheckpoint with LitModel support.
222234
223235
Args:
224-
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
225-
args: Additional arguments to pass to the parent class.
226-
kwargs: Additional keyword arguments to pass to the parent class.
236+
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
237+
clear_all_local: Whether to clear local models after uploading to the cloud.
238+
*args: Additional arguments to pass to the parent class.
239+
**kwargs: Additional keyword arguments to pass to the parent class.
227240
"""
228241

229-
def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
242+
def __init__(
243+
self,
244+
*args: Any,
245+
model_name: Optional[str] = None,
246+
model_registry: Optional[str] = None,
247+
clear_all_local: bool = False,
248+
**kwargs: Any,
249+
) -> None:
230250
"""Initialize the checkpoint with model name and other parameters."""
231251
_LightningModelCheckpoint.__init__(self, *args, **kwargs)
232-
LitModelCheckpointMixin.__init__(self, model_name)
252+
if model_name is not None:
253+
rank_zero_warn(
254+
"The 'model_name' argument is deprecated and will be removed in a future version."
255+
" Please use 'model_registry' instead."
256+
)
257+
LitModelCheckpointMixin.__init__(
258+
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
259+
)
233260

234261
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
235262
"""Setup the checkpoint callback."""
@@ -240,7 +267,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
240267
"""Extend the save checkpoint method to upload the model."""
241268
_LightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
242269
if trainer.is_global_zero: # Only upload from the main process
243-
self._upload_model(filepath)
270+
self._upload_model(trainer=trainer, filepath=filepath)
244271

245272
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
246273
"""Extend the on_fit_end method to ensure all uploads are completed."""
@@ -251,7 +278,7 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
251278
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
252279
"""Extend the remove checkpoint method to remove the model from the registry."""
253280
if trainer.is_global_zero: # Only remove from the main process
254-
self._remove_model(trainer, filepath)
281+
self._remove_model(trainer=trainer, filepath=filepath)
255282

256283

257284
if _PYTORCHLIGHTNING_AVAILABLE:
@@ -260,15 +287,30 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
260287
"""PyTorch Lightning ModelCheckpoint with LitModel support.
261288
262289
Args:
263-
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
290+
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
291+
clear_all_local: Whether to clear local models after uploading to the cloud.
264292
args: Additional arguments to pass to the parent class.
265293
kwargs: Additional keyword arguments to pass to the parent class.
266294
"""
267295

268-
def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
296+
def __init__(
297+
self,
298+
*args: Any,
299+
model_name: Optional[str] = None,
300+
model_registry: Optional[str] = None,
301+
clear_all_local: bool = False,
302+
**kwargs: Any,
303+
) -> None:
269304
"""Initialize the checkpoint with model name and other parameters."""
270305
_PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
271-
LitModelCheckpointMixin.__init__(self, model_name)
306+
if model_name is not None:
307+
rank_zero_warn(
308+
"The 'model_name' argument is deprecated and will be removed in a future version."
309+
" Please use 'model_registry' instead."
310+
)
311+
LitModelCheckpointMixin.__init__(
312+
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
313+
)
272314

273315
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
274316
"""Setup the checkpoint callback."""
@@ -279,7 +321,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
279321
"""Extend the save checkpoint method to upload the model."""
280322
_PytorchLightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
281323
if trainer.is_global_zero: # Only upload from the main process
282-
self._upload_model(filepath)
324+
self._upload_model(trainer=trainer, filepath=filepath)
283325

284326
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
285327
"""Extend the on_fit_end method to ensure all uploads are completed."""
@@ -290,4 +332,4 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
290332
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
291333
"""Extend the remove checkpoint method to remove the model from the registry."""
292334
if trainer.is_global_zero: # Only remove from the main process
293-
self._remove_model(trainer, filepath)
335+
self._remove_model(trainer=trainer, filepath=filepath)

tests/integrations/test_checkpoints.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
@pytest.mark.parametrize(
1919
"model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"]
2020
)
21+
@pytest.mark.parametrize("clear_all_local", [True, False])
2122
@mock.patch("litmodels.io.cloud.sdk_upload_model")
2223
@mock.patch("litmodels.integrations.checkpoints.Auth")
23-
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch, importing, model_name, tmp_path):
24+
def test_lightning_checkpoint_callback(
25+
mock_auth, mock_upload_model, monkeypatch, importing, model_name, clear_all_local, tmp_path
26+
):
2427
if importing == "lightning":
25-
from lightning import Trainer
28+
from lightning.pytorch import Trainer
2629
from lightning.pytorch.callbacks import ModelCheckpoint
2730
from lightning.pytorch.demos.boring_classes import BoringModel
2831

@@ -37,7 +40,10 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
3740
# Validate inheritance
3841
assert issubclass(LitModelCheckpoint, ModelCheckpoint)
3942

40-
ckpt_args = {"model_name": model_name} if model_name else {}
43+
ckpt_args = {"clear_all_local": clear_all_local}
44+
if model_name:
45+
ckpt_args.update({"model_registry": model_name})
46+
4147
all_model_registry = {
4248
"org-name/teamspace/model-name": {"org": "org-name", "teamspace": "teamspace", "model": "model-name"},
4349
"model-in-studio": {"org": "my-org", "teamspace": "dream-team", "model": "model-in-studio"},
@@ -71,10 +77,14 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
7177
mock.MagicMock(return_value={f"{expected_org}/{expected_teamspace}": {}}),
7278
)
7379

80+
# mock the trainer strategy's checkpoint removal so deletions can be counted
81+
mock_remove_ckpt = mock.Mock()
82+
# setting the Trainer and custom checkpointing
7483
trainer = Trainer(
7584
max_epochs=2,
7685
callbacks=LitModelCheckpoint(**ckpt_args),
7786
)
87+
trainer.strategy.remove_checkpoint = mock_remove_ckpt
7888
trainer.fit(BoringModel())
7989

8090
assert mock_auth.call_count == 1
@@ -88,6 +98,8 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
8898
)
8999
for v in ("epoch=0-step=64", "epoch=1-step=128")
90100
]
101+
expected_removals = 2 if clear_all_local else 1
102+
assert mock_remove_ckpt.call_count == expected_removals
91103

92104
# Verify paths match the expected pattern
93105
for call_args in mock_upload_model.call_args_list:
@@ -109,6 +121,6 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
109121
elif importing == "pytorch_lightning":
110122
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
111123

112-
ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
124+
ckpt = LitModelCheckpoint(model_registry="org-name/teamspace/model-name")
113125
assert mock_auth.call_count == 1
114126
pickle.dumps(ckpt)

0 commit comments

Comments
 (0)