Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
lightning-sdk >=0.2.10
lightning-sdk >=0.2.11
lightning-utilities
58 changes: 45 additions & 13 deletions src/litmodels/integrations/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from litmodels import upload_model
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
from litmodels.io.cloud import _list_available_teamspaces
from litmodels.io.cloud import _list_available_teamspaces, delete_model_version

if _LIGHTNING_AVAILABLE:
from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
Expand Down Expand Up @@ -47,6 +47,13 @@ class Action(StrEnum):
REMOVE = "remove"


class RemoveType(StrEnum):
    """Enumeration of possible remove types for the ModelManager."""

    # NOTE(review): not referenced in the visible hunks; presumably distinguishes removing the
    # on-disk checkpoint from removing the uploaded model version - confirm against usage.
    LOCAL = "local"
    CLOUD = "cloud"


class ModelManager:
"""Manages uploads and removals with a single queue but separate counters."""

Expand Down Expand Up @@ -94,10 +101,16 @@ def _worker_loop(self) -> None:
finally:
self.upload_count -= 1
elif action == Action.REMOVE:
trainer, filepath = detail
filepath, trainer, registry_name = detail
try:
trainer.strategy.remove_checkpoint(filepath)
rank_zero_debug(f"Removed file: {filepath}")
if registry_name:
rank_zero_debug(f"Removing from cloud: {filepath}")
# Remove from the cloud
version = os.path.splitext(os.path.basename(filepath))[0]
delete_model_version(name=registry_name, version=version)
if trainer:
rank_zero_debug(f"Removed local file: {filepath}")
trainer.strategy.remove_checkpoint(filepath)
except Exception as ex:
rank_zero_warn(f"Removal failed {filepath}: {ex}")
finally:
Expand All @@ -112,10 +125,12 @@ def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata:
self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")

def queue_remove(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
def queue_remove(
self, filepath: Union[str, Path], trainer: Optional["pl.Trainer"] = None, registry_name: Optional[str] = None
) -> None:
"""Queue a removal task."""
self.remove_count += 1
self.task_queue.put((Action.REMOVE, (trainer, filepath)))
self.task_queue.put((Action.REMOVE, (filepath, trainer, registry_name)))
rank_zero_debug(f"Queued removal: {filepath} (pending removals: {self.remove_count})")

def shutdown(self) -> None:
Expand All @@ -133,11 +148,14 @@ class LitModelCheckpointMixin(ABC):
model_registry: Optional[str] = None
_model_manager: ModelManager

def __init__(self, model_registry: Optional[str], clear_all_local: bool = False) -> None:
def __init__(
self, model_registry: Optional[str], keep_all_uploaded: bool = False, clear_all_local: bool = False
) -> None:
"""Initialize with model name.

Args:
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
clear_all_local: Whether to clear local models after uploading to the cloud.
"""
if not model_registry:
Expand All @@ -147,6 +165,7 @@ def __init__(self, model_registry: Optional[str], clear_all_local: bool = False)
self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M")
# remove any / from beginning and end of the name
self.model_registry = model_registry.strip("/") if model_registry else None
self._keep_all_uploaded = keep_all_uploaded
self._clear_all_local = clear_all_local

try: # authenticate before anything else starts
Expand Down Expand Up @@ -178,15 +197,18 @@ def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metad
# Add to queue instead of uploading directly
get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
if self._clear_all_local:
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
get_model_manager().queue_remove(filepath=filepath, trainer=trainer)

@rank_zero_only
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
"""Remove the local version of the model if requested."""
if self._clear_all_local:
get_model_manager().queue_remove(
filepath=filepath,
# skip the local removal; it was already queued right after the upload
return
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
trainer=None if self._clear_all_local else trainer,
# skip the cloud removal if we keep all uploaded models
registry_name=None if self._keep_all_uploaded else self.model_registry,
)

def default_model_name(self, pl_model: "pl.LightningModule") -> str:
"""Generate a default model name based on the class name and timestamp."""
Expand Down Expand Up @@ -234,6 +256,7 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin

Args:
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
clear_all_local: Whether to clear local models after uploading to the cloud.
*args: Additional arguments to pass to the parent class.
**kwargs: Additional keyword arguments to pass to the parent class.
Expand All @@ -244,6 +267,7 @@ def __init__(
*args: Any,
model_name: Optional[str] = None,
model_registry: Optional[str] = None,
keep_all_uploaded: bool = False,
clear_all_local: bool = False,
**kwargs: Any,
) -> None:
Expand All @@ -255,7 +279,10 @@ def __init__(
" Please use 'model_registry' instead."
)
LitModelCheckpointMixin.__init__(
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
self,
model_registry=model_registry or model_name,
keep_all_uploaded=keep_all_uploaded,
clear_all_local=clear_all_local,
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
Expand Down Expand Up @@ -288,6 +315,7 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning

Args:
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
keep_all_uploaded: Whether to prevent deleting models from the cloud when the checkpointing logic asks to do so.
clear_all_local: Whether to clear local models after uploading to the cloud.
args: Additional arguments to pass to the parent class.
kwargs: Additional keyword arguments to pass to the parent class.
Expand All @@ -298,6 +326,7 @@ def __init__(
*args: Any,
model_name: Optional[str] = None,
model_registry: Optional[str] = None,
keep_all_uploaded: bool = False,
clear_all_local: bool = False,
**kwargs: Any,
) -> None:
Expand All @@ -309,7 +338,10 @@ def __init__(
" Please use 'model_registry' instead."
)
LitModelCheckpointMixin.__init__(
self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
self,
model_registry=model_registry or model_name,
keep_all_uploaded=keep_all_uploaded,
clear_all_local=clear_all_local,
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
Expand Down
15 changes: 15 additions & 0 deletions src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
from lightning_sdk.models import _extend_model_name_with_teamspace, _parse_org_teamspace_model_version
from lightning_sdk.models import delete_model as sdk_delete_model
from lightning_sdk.models import download_model as sdk_download_model
from lightning_sdk.models import upload_model as sdk_upload_model

Expand Down Expand Up @@ -123,3 +124,17 @@ def _list_available_teamspaces() -> Dict[str, dict]:
else:
raise RuntimeError(f"Unknown organization type {ts.organization_type}")
return teamspaces


def delete_model_version(
    name: str,
    version: Optional[str] = None,
) -> None:
    """Delete a model version from the model store.

    Args:
        name: Name of the model to delete. Must be in the format 'organization/teamspace/modelname'
            where organization is either your username or the name of an organization you are part of.
        version: Version of the model to delete. If not provided, all versions will be deleted.
    """
    # Append the version suffix only when a version was actually given. Formatting the
    # default ``None`` into the name would ask the SDK to delete a version literally
    # called "None" instead of addressing the model as a whole.
    target = name if version is None else f"{name}:{version}"
    sdk_delete_model(name=target)
25 changes: 21 additions & 4 deletions tests/integrations/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,20 @@
"model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"]
)
@pytest.mark.parametrize("clear_all_local", [True, False])
@pytest.mark.parametrize("keep_all_uploaded", [True, False])
@mock.patch("litmodels.io.cloud.sdk_delete_model")
@mock.patch("litmodels.io.cloud.sdk_upload_model")
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpoint_callback(
mock_auth, mock_upload_model, monkeypatch, importing, model_name, clear_all_local, tmp_path
mock_auth,
mock_upload_model,
mock_delete_model,
monkeypatch,
importing,
model_name,
clear_all_local,
keep_all_uploaded,
tmp_path,
):
if importing == "lightning":
from lightning.pytorch import Trainer
Expand All @@ -40,7 +50,7 @@ def test_lightning_checkpoint_callback(
# Validate inheritance
assert issubclass(LitModelCheckpoint, ModelCheckpoint)

ckpt_args = {"clear_all_local": clear_all_local}
ckpt_args = {"clear_all_local": clear_all_local, "keep_all_uploaded": keep_all_uploaded}
if model_name:
ckpt_args.update({"model_registry": model_name})

Expand Down Expand Up @@ -98,8 +108,15 @@ def test_lightning_checkpoint_callback(
)
for v in ("epoch=0-step=64", "epoch=1-step=128")
]
expected_removals = 2 if clear_all_local else 1
assert mock_remove_ckpt.call_count == expected_removals
expected_local_removals = 2 if clear_all_local else 1
assert mock_remove_ckpt.call_count == expected_local_removals

expected_cloud_removals = 0 if keep_all_uploaded else 1
assert mock_delete_model.call_count == expected_cloud_removals
if expected_cloud_removals:
mock_delete_model.assert_called_once_with(
name=f"{expected_org}/{expected_teamspace}/{expected_model}:epoch=0-step=64"
)

# Verify paths match the expected pattern
for call_args in mock_upload_model.call_args_list:
Expand Down
Loading