From a2f0e268171fa70676e03ff296c77ed7e50cb3d2 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Wed, 12 Mar 2025 22:27:23 +0100 Subject: [PATCH 1/2] ckpt: rank_zero_only --- src/litmodels/__about__.py | 2 +- src/litmodels/integrations/checkpoints.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/litmodels/__about__.py b/src/litmodels/__about__.py index 5d0b459..b6162ed 100644 --- a/src/litmodels/__about__.py +++ b/src/litmodels/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.1" +__version__ = "0.1.2rc" __author__ = "Lightning-AI et al." __author_email__ = "community@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 7d76852..9dd7c53 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -1,5 +1,6 @@ from typing import Any, Type, TypeVar, cast +from lightning.fabric.utilities import rank_zero_only from lightning_sdk.lightning_cloud.login import Auth from litmodels import upload_model @@ -48,12 +49,16 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: except Exception: raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) + @rank_zero_only + def _upload_model(self, filepath: str) -> None: # todo: uploading on background so training does nt stops # todo: use filename as version but need to validate that such version does not exists yet upload_model(name=self.model_name, model=filepath) + def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + self._upload_model(filepath) + return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) From 11360610714dd5b957ee5b3d06d13944ca4d7f19 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 13 Mar 2025 08:08:04 +0100 Subject: [PATCH 2/2] lightning_utilities --- src/litmodels/integrations/checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py index 9dd7c53..d782721 100644 --- a/src/litmodels/integrations/checkpoints.py +++ b/src/litmodels/integrations/checkpoints.py @@ -1,7 +1,7 @@ from typing import Any, Type, TypeVar, cast -from lightning.fabric.utilities import rank_zero_only from lightning_sdk.lightning_cloud.login import Auth +from lightning_utilities.core.rank_zero import rank_zero_only from litmodels import upload_model from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE