Skip to content

Commit a26d86d

Browse files
feat: customized model cleaning / prune registry (#98)
* feat: customized model cleaning * mock + testing * lightning-sdk >=0.2.11 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4c381c6 commit a26d86d

File tree

4 files changed

+82
-18
lines changed

4 files changed

+82
-18
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
lightning-sdk >=0.2.10
1+
lightning-sdk >=0.2.11
22
lightning-utilities

src/litmodels/integrations/checkpoints.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from litmodels import upload_model
1717
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
18-
from litmodels.io.cloud import _list_available_teamspaces
18+
from litmodels.io.cloud import _list_available_teamspaces, delete_model_version
1919

2020
if _LIGHTNING_AVAILABLE:
2121
from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
@@ -47,6 +47,13 @@ class Action(StrEnum):
4747
REMOVE = "remove"
4848

4949

50+
class RemoveType(StrEnum):
51+
"""Enumeration of possible remove types for the ModelManager."""
52+
53+
LOCAL = "local"
54+
CLOUD = "cloud"
55+
56+
5057
class ModelManager:
5158
"""Manages uploads and removals with a single queue but separate counters."""
5259

@@ -94,10 +101,16 @@ def _worker_loop(self) -> None:
94101
finally:
95102
self.upload_count -= 1
96103
elif action == Action.REMOVE:
97-
trainer, filepath = detail
104+
filepath, trainer, registry_name = detail
98105
try:
99-
trainer.strategy.remove_checkpoint(filepath)
100-
rank_zero_debug(f"Removed file: {filepath}")
106+
if registry_name:
107+
rank_zero_debug(f"Removing from cloud: {filepath}")
108+
# Remove from the cloud
109+
version = os.path.splitext(os.path.basename(filepath))[0]
110+
delete_model_version(name=registry_name, version=version)
111+
if trainer:
112+
rank_zero_debug(f"Removed local file: {filepath}")
113+
trainer.strategy.remove_checkpoint(filepath)
101114
except Exception as ex:
102115
rank_zero_warn(f"Removal failed {filepath}: {ex}")
103116
finally:
@@ -112,10 +125,12 @@ def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata:
112125
self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
113126
rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")
114127

115-
def queue_remove(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
128+
def queue_remove(
129+
self, filepath: Union[str, Path], trainer: Optional["pl.Trainer"] = None, registry_name: Optional[str] = None
130+
) -> None:
116131
"""Queue a removal task."""
117132
self.remove_count += 1
118-
self.task_queue.put((Action.REMOVE, (trainer, filepath)))
133+
self.task_queue.put((Action.REMOVE, (filepath, trainer, registry_name)))
119134
rank_zero_debug(f"Queued removal: {filepath} (pending removals: {self.remove_count})")
120135

121136
def shutdown(self) -> None:
@@ -133,11 +148,14 @@ class LitModelCheckpointMixin(ABC):
133148
model_registry: Optional[str] = None
134149
_model_manager: ModelManager
135150

136-
def __init__(self, model_registry: Optional[str], clear_all_local: bool = False) -> None:
151+
def __init__(
152+
self, model_registry: Optional[str], keep_all_uploaded: bool = False, clear_all_local: bool = False
153+
) -> None:
137154
"""Initialize with model name.
138155
139156
Args:
140157
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
158+
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
141159
clear_all_local: Whether to clear local models after uploading to the cloud.
142160
"""
143161
if not model_registry:
@@ -147,6 +165,7 @@ def __init__(self, model_registry: Optional[str], clear_all_local: bool = False)
147165
self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M")
148166
# remove any / from beginning and end of the name
149167
self.model_registry = model_registry.strip("/") if model_registry else None
168+
self._keep_all_uploaded = keep_all_uploaded
150169
self._clear_all_local = clear_all_local
151170

152171
try: # authenticate before anything else starts
@@ -178,15 +197,18 @@ def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metad
178197
# Add to queue instead of uploading directly
179198
get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
180199
if self._clear_all_local:
181-
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
200+
get_model_manager().queue_remove(filepath=filepath, trainer=trainer)
182201

183202
@rank_zero_only
184203
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
185204
"""Queue removal of the model: locally unless that was already queued after upload, and from the cloud unless all uploads are kept."""
186-
if self._clear_all_local:
205+
get_model_manager().queue_remove(
206+
filepath=filepath,
187207
# skip the local removal if it was already queued right after the upload
188-
return
189-
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
208+
trainer=None if self._clear_all_local else trainer,
209+
# skip the cloud removal if we keep all uploaded models
210+
registry_name=None if self._keep_all_uploaded else self.model_registry,
211+
)
190212

191213
def default_model_name(self, pl_model: "pl.LightningModule") -> str:
192214
"""Generate a default model name based on the class name and timestamp."""
@@ -234,6 +256,7 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
234256
235257
Args:
236258
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
259+
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
237260
clear_all_local: Whether to clear local models after uploading to the cloud.
238261
*args: Additional arguments to pass to the parent class.
239262
**kwargs: Additional keyword arguments to pass to the parent class.
@@ -244,6 +267,7 @@ def __init__(
244267
*args: Any,
245268
model_name: Optional[str] = None,
246269
model_registry: Optional[str] = None,
270+
keep_all_uploaded: bool = False,
247271
clear_all_local: bool = False,
248272
**kwargs: Any,
249273
) -> None:
@@ -255,7 +279,10 @@ def __init__(
255279
" Please use 'model_registry' instead."
256280
)
257281
LitModelCheckpointMixin.__init__(
258-
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
282+
self,
283+
model_registry=model_registry or model_name,
284+
keep_all_uploaded=keep_all_uploaded,
285+
clear_all_local=clear_all_local,
259286
)
260287

261288
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
@@ -288,6 +315,7 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
288315
289316
Args:
290317
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
318+
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
291319
clear_all_local: Whether to clear local models after uploading to the cloud.
292320
args: Additional arguments to pass to the parent class.
293321
kwargs: Additional keyword arguments to pass to the parent class.
@@ -298,6 +326,7 @@ def __init__(
298326
*args: Any,
299327
model_name: Optional[str] = None,
300328
model_registry: Optional[str] = None,
329+
keep_all_uploaded: bool = False,
301330
clear_all_local: bool = False,
302331
**kwargs: Any,
303332
) -> None:
@@ -309,7 +338,10 @@ def __init__(
309338
" Please use 'model_registry' instead."
310339
)
311340
LitModelCheckpointMixin.__init__(
312-
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
341+
self,
342+
model_registry=model_registry or model_name,
343+
keep_all_uploaded=keep_all_uploaded,
344+
clear_all_local=clear_all_local,
313345
)
314346

315347
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:

src/litmodels/io/cloud.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
99
from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version
10+
from lightning_sdk.models import delete_model as sdk_delete_model
1011
from lightning_sdk.models import download_model as sdk_download_model
1112
from lightning_sdk.models import upload_model as sdk_upload_model
1213

@@ -123,3 +124,17 @@ def _list_available_teamspaces() -> Dict[str, dict]:
123124
else:
124125
raise RuntimeError(f"Unknown organization type {ts.organization_type}")
125126
return teamspaces
127+
128+
129+
def delete_model_version(
130+
name: str,
131+
version: Optional[str] = None,
132+
) -> None:
133+
"""Delete a model version from the model store.
134+
135+
Args:
136+
name: Name of the model to delete. Must be in the format 'organization/teamspace/modelname'
137+
where organization is either your username or the name of an organization you are part of.
138+
version: Version of the model to delete. If not provided, all versions will be deleted.
139+
"""
140+
sdk_delete_model(name=f"{name}:{version}")

tests/integrations/test_checkpoints.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,20 @@
1919
"model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"]
2020
)
2121
@pytest.mark.parametrize("clear_all_local", [True, False])
22+
@pytest.mark.parametrize("keep_all_uploaded", [True, False])
23+
@mock.patch("litmodels.io.cloud.sdk_delete_model")
2224
@mock.patch("litmodels.io.cloud.sdk_upload_model")
2325
@mock.patch("litmodels.integrations.checkpoints.Auth")
2426
def test_lightning_checkpoint_callback(
25-
mock_auth, mock_upload_model, monkeypatch, importing, model_name, clear_all_local, tmp_path
27+
mock_auth,
28+
mock_upload_model,
29+
mock_delete_model,
30+
monkeypatch,
31+
importing,
32+
model_name,
33+
clear_all_local,
34+
keep_all_uploaded,
35+
tmp_path,
2636
):
2737
if importing == "lightning":
2838
from lightning.pytorch import Trainer
@@ -40,7 +50,7 @@ def test_lightning_checkpoint_callback(
4050
# Validate inheritance
4151
assert issubclass(LitModelCheckpoint, ModelCheckpoint)
4252

43-
ckpt_args = {"clear_all_local": clear_all_local}
53+
ckpt_args = {"clear_all_local": clear_all_local, "keep_all_uploaded": keep_all_uploaded}
4454
if model_name:
4555
ckpt_args.update({"model_registry": model_name})
4656

@@ -98,8 +108,15 @@ def test_lightning_checkpoint_callback(
98108
)
99109
for v in ("epoch=0-step=64", "epoch=1-step=128")
100110
]
101-
expected_removals = 2 if clear_all_local else 1
102-
assert mock_remove_ckpt.call_count == expected_removals
111+
expected_local_removals = 2 if clear_all_local else 1
112+
assert mock_remove_ckpt.call_count == expected_local_removals
113+
114+
expected_cloud_removals = 0 if keep_all_uploaded else 1
115+
assert mock_delete_model.call_count == expected_cloud_removals
116+
if expected_cloud_removals:
117+
mock_delete_model.assert_called_once_with(
118+
name=f"{expected_org}/{expected_teamspace}/{expected_model}:epoch=0-step=64"
119+
)
103120

104121
# Verify paths match the expected pattern
105122
for call_args in mock_upload_model.call_args_list:

0 commit comments

Comments
 (0)