Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ trainer = Trainer(
callbacks=[
LightningModelCheckpoint(
# Define the model name - this should be unique to your model
model_name="<organization>/<teamspace>/<model-name>",
model_registry="<organization>/<teamspace>/<model-name>",
)
],
)
Expand Down
2 changes: 1 addition & 1 deletion examples/train-model-with-lightning-callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
if __name__ == "__main__":
trainer = Trainer(
max_epochs=2,
callbacks=LightningModelCheckpoint(model_name=MY_MODEL_NAME),
callbacks=LightningModelCheckpoint(model_registry=MY_MODEL_NAME),
)
trainer.fit(BoringModel())
86 changes: 64 additions & 22 deletions src/litmodels/integrations/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from abc import ABC
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

from lightning_sdk.lightning_cloud.login import Auth
from lightning_sdk.utils.resolve import _resolve_teamspace
Expand Down Expand Up @@ -105,13 +106,13 @@ def _worker_loop(self) -> None:
rank_zero_warn(f"Unknown task: {task}")
self.task_queue.task_done()

def queue_upload(self, registry_name: str, filepath: str, metadata: Optional[dict] = None) -> None:
def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
    """Queue an upload task for the background worker.

    Args:
        registry_name: Target model registry identifier.
        filepath: Local checkpoint path to upload.
        metadata: Optional metadata to attach to the upload.
    """
    self.upload_count += 1
    # Bundle the task payload and hand it to the worker queue.
    payload = (registry_name, filepath, metadata)
    self.task_queue.put((Action.UPLOAD, payload))
    rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")

def queue_remove(self, trainer: "pl.Trainer", filepath: str) -> None:
def queue_remove(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
"""Queue a removal task."""
self.remove_count += 1
self.task_queue.put((Action.REMOVE, (trainer, filepath)))
Expand All @@ -132,15 +133,21 @@ class LitModelCheckpointMixin(ABC):
model_registry: Optional[str] = None
_model_manager: ModelManager

def __init__(self, model_name: Optional[str]) -> None:
"""Initialize with model name."""
if not model_name:
def __init__(self, model_registry: Optional[str], clear_all_local: bool = False) -> None:
"""Initialize with model name.

Args:
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
clear_all_local: Whether to clear local models after uploading to the cloud.
"""
if not model_registry:
rank_zero_warn(
"The model is not defined so we will continue with LightningModule names and timestamp of now"
)
self._datetime_stamp = datetime.now().strftime("%Y%m%d-%H%M")
# remove any / from beginning and end of the name
self.model_registry = model_name.strip("/") if model_name else None
self.model_registry = model_registry.strip("/") if model_registry else None
self._clear_all_local = clear_all_local

try: # authenticate before anything else starts
Auth().authenticate()
Expand All @@ -150,7 +157,7 @@ def __init__(self, model_name: Optional[str]) -> None:
self._model_manager = ModelManager()

@rank_zero_only
def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
if not self.model_registry:
raise RuntimeError(
"Model name is not specified neither updated by `setup` method via Trainer."
Expand All @@ -170,11 +177,16 @@ def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
metadata.update({"litModels_integration": ckpt_class.__name__})
# Add to queue instead of uploading directly
get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
if self._clear_all_local:
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)

@rank_zero_only
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
"""Remove the local version of the model if requested."""
get_model_manager().queue_remove(trainer, filepath)
if self._clear_all_local:
# skip the local removal here; it was already queued right after the upload
return
get_model_manager().queue_remove(trainer=trainer, filepath=filepath)

def default_model_name(self, pl_model: "pl.LightningModule") -> str:
"""Generate a default model name based on the class name and timestamp."""
Expand Down Expand Up @@ -221,15 +233,30 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
"""Lightning ModelCheckpoint with LitModel support.

Args:
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
args: Additional arguments to pass to the parent class.
kwargs: Additional keyword arguments to pass to the parent class.
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
clear_all_local: Whether to clear local models after uploading to the cloud.
*args: Additional arguments to pass to the parent class.
**kwargs: Additional keyword arguments to pass to the parent class.
"""

def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
def __init__(
    self,
    *args: Any,
    model_name: Optional[str] = None,
    model_registry: Optional[str] = None,
    clear_all_local: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the checkpoint with model name and other parameters.

    Args:
        *args: Positional arguments forwarded to the parent ``ModelCheckpoint``.
        model_name: Deprecated alias for ``model_registry``; emits a warning when set.
        model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
        clear_all_local: Whether to clear local models after uploading to the cloud.
        **kwargs: Keyword arguments forwarded to the parent ``ModelCheckpoint``.
    """
    _LightningModelCheckpoint.__init__(self, *args, **kwargs)
    # Keep accepting the legacy `model_name` argument, but nudge users to migrate.
    if model_name is not None:
        rank_zero_warn(
            "The 'model_name' argument is deprecated and will be removed in a future version."
            " Please use 'model_registry' instead."
        )
    # `model_registry` wins when both are given; otherwise fall back to the deprecated alias.
    LitModelCheckpointMixin.__init__(
        self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
    )

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Setup the checkpoint callback."""
Expand All @@ -240,7 +267,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Extend the save checkpoint method to upload the model."""
_LightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
if trainer.is_global_zero: # Only upload from the main process
self._upload_model(filepath)
self._upload_model(trainer=trainer, filepath=filepath)

def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Extend the on_fit_end method to ensure all uploads are completed."""
Expand All @@ -251,7 +278,7 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Extend the remove checkpoint method to remove the model from the registry."""
if trainer.is_global_zero: # Only remove from the main process
self._remove_model(trainer, filepath)
self._remove_model(trainer=trainer, filepath=filepath)


if _PYTORCHLIGHTNING_AVAILABLE:
Expand All @@ -260,15 +287,30 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
"""PyTorch Lightning ModelCheckpoint with LitModel support.

Args:
model_name: Name of the model to upload in format 'organization/teamspace/modelname'
model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
clear_all_local: Whether to clear local models after uploading to the cloud.
args: Additional arguments to pass to the parent class.
kwargs: Additional keyword arguments to pass to the parent class.
"""

def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any) -> None:
def __init__(
    self,
    *args: Any,
    model_name: Optional[str] = None,
    model_registry: Optional[str] = None,
    clear_all_local: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the checkpoint with model name and other parameters.

    Args:
        *args: Positional arguments forwarded to the parent ``ModelCheckpoint``.
        model_name: Deprecated alias for ``model_registry``; emits a warning when set.
        model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
        clear_all_local: Whether to clear local models after uploading to the cloud.
        **kwargs: Keyword arguments forwarded to the parent ``ModelCheckpoint``.
    """
    _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
    # Keep accepting the legacy `model_name` argument, but nudge users to migrate.
    if model_name is not None:
        rank_zero_warn(
            "The 'model_name' argument is deprecated and will be removed in a future version."
            " Please use 'model_registry' instead."
        )
    # `model_registry` wins when both are given; otherwise fall back to the deprecated alias.
    LitModelCheckpointMixin.__init__(
        self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
    )

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Setup the checkpoint callback."""
Expand All @@ -279,7 +321,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Extend the save checkpoint method to upload the model."""
_PytorchLightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
if trainer.is_global_zero: # Only upload from the main process
self._upload_model(filepath)
self._upload_model(trainer=trainer, filepath=filepath)

def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Extend the on_fit_end method to ensure all uploads are completed."""
Expand All @@ -290,4 +332,4 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
"""Extend the remove checkpoint method to remove the model from the registry."""
if trainer.is_global_zero: # Only remove from the main process
self._remove_model(trainer, filepath)
self._remove_model(trainer=trainer, filepath=filepath)
20 changes: 16 additions & 4 deletions tests/integrations/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
@pytest.mark.parametrize(
"model_name", [None, "org-name/teamspace/model-name", "model-in-studio", "model-user-only-project"]
)
@pytest.mark.parametrize("clear_all_local", [True, False])
@mock.patch("litmodels.io.cloud.sdk_upload_model")
@mock.patch("litmodels.integrations.checkpoints.Auth")
def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch, importing, model_name, tmp_path):
def test_lightning_checkpoint_callback(
mock_auth, mock_upload_model, monkeypatch, importing, model_name, clear_all_local, tmp_path
):
if importing == "lightning":
from lightning import Trainer
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

Expand All @@ -37,7 +40,10 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
# Validate inheritance
assert issubclass(LitModelCheckpoint, ModelCheckpoint)

ckpt_args = {"model_name": model_name} if model_name else {}
ckpt_args = {"clear_all_local": clear_all_local}
if model_name:
ckpt_args.update({"model_registry": model_name})

all_model_registry = {
"org-name/teamspace/model-name": {"org": "org-name", "teamspace": "teamspace", "model": "model-name"},
"model-in-studio": {"org": "my-org", "teamspace": "dream-team", "model": "model-in-studio"},
Expand Down Expand Up @@ -71,10 +77,14 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
mock.MagicMock(return_value={f"{expected_org}/{expected_teamspace}": {}}),
)

# mocking the trainer delete checkpoint removal
mock_remove_ckpt = mock.Mock()
# setting the Trainer and custom checkpointing
trainer = Trainer(
max_epochs=2,
callbacks=LitModelCheckpoint(**ckpt_args),
)
trainer.strategy.remove_checkpoint = mock_remove_ckpt
trainer.fit(BoringModel())

assert mock_auth.call_count == 1
Expand All @@ -88,6 +98,8 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
)
for v in ("epoch=0-step=64", "epoch=1-step=128")
]
expected_removals = 2 if clear_all_local else 1
assert mock_remove_ckpt.call_count == expected_removals

# Verify paths match the expected pattern
for call_args in mock_upload_model.call_args_list:
Expand All @@ -109,6 +121,6 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
elif importing == "pytorch_lightning":
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint

ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
ckpt = LitModelCheckpoint(model_registry="org-name/teamspace/model-name")
assert mock_auth.call_count == 1
pickle.dumps(ckpt)
Loading