diff --git a/src/litmodels/__about__.py b/src/litmodels/__about__.py
index 5d0b459..b6162ed 100644
--- a/src/litmodels/__about__.py
+++ b/src/litmodels/__about__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2rc"
 __author__ = "Lightning-AI et al."
 __author_email__ = "community@lightning.ai"
 __license__ = "Apache-2.0"
diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py
index 7d76852..d782721 100644
--- a/src/litmodels/integrations/checkpoints.py
+++ b/src/litmodels/integrations/checkpoints.py
@@ -1,6 +1,7 @@
 from typing import Any, Type, TypeVar, cast
 
 from lightning_sdk.lightning_cloud.login import Auth
+from lightning_utilities.core.rank_zero import rank_zero_only
 
 from litmodels import upload_model
 from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
@@ -48,12 +49,16 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
             except Exception:
                 raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
 
-        def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
-            super()._save_checkpoint(trainer, filepath)
+        @rank_zero_only
+        def _upload_model(self, filepath: str) -> None:
             # todo: uploading on background so training does nt stops
             # todo: use filename as version but need to validate that such version does not exists yet
             upload_model(name=self.model_name, model=filepath)
 
+        def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
+            super()._save_checkpoint(trainer, filepath)
+            self._upload_model(filepath)
+
     return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate)
 
 