Skip to content

Commit a2f0e26

Browse files
committed
ckpt: rank_zero_only
1 parent b7950b6 commit a2f0e26

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/litmodels/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2rc"
22
__author__ = "Lightning-AI et al."
33
__author_email__ = "[email protected]"
44
__license__ = "Apache-2.0"

src/litmodels/integrations/checkpoints.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Type, TypeVar, cast
22

3+
from lightning.fabric.utilities import rank_zero_only
34
from lightning_sdk.lightning_cloud.login import Auth
45

56
from litmodels import upload_model
@@ -48,12 +49,16 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
4849
except Exception:
4950
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
5051

51-
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
52-
super()._save_checkpoint(trainer, filepath)
52+
@rank_zero_only
53+
def _upload_model(self, filepath: str) -> None:
5354
# todo: upload in the background so training does not stop
5455
# todo: use the filename as the version, but first validate that such a version does not exist yet
5556
upload_model(name=self.model_name, model=filepath)
5657

58+
def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None:
59+
super()._save_checkpoint(trainer, filepath)
60+
self._upload_model(filepath)
61+
5762
return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate)
5863

5964

0 commit comments

Comments
 (0)