1- from typing import TYPE_CHECKING , Any
1+ from abc import ABC
2+ from datetime import datetime
3+ from typing import TYPE_CHECKING , Any , Optional
24
35from lightning_sdk .lightning_cloud .login import Auth
4- from lightning_utilities .core .rank_zero import rank_zero_only
6+ from lightning_utilities .core .rank_zero import rank_zero_only , rank_zero_warn
57
68from litmodels import upload_model
79from litmodels .integrations .imports import _LIGHTNING_AVAILABLE , _PYTORCHLIGHTNING_AVAILABLE
810
911if _LIGHTNING_AVAILABLE :
1012 from lightning .pytorch .callbacks import ModelCheckpoint as _LightningModelCheckpoint
1113
12- if TYPE_CHECKING :
13- from lightning .pytorch import Trainer
14-
1514
1615if _PYTORCHLIGHTNING_AVAILABLE :
1716 from pytorch_lightning .callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
1817
19- if TYPE_CHECKING :
20- from pytorch_lightning import Trainer
18+
19+ if TYPE_CHECKING :
20+ if _LIGHTNING_AVAILABLE :
21+ import lightning .pytorch as pl
22+ if _PYTORCHLIGHTNING_AVAILABLE :
23+ import pytorch_lightning as pl
2124
2225
2326# Base class to be inherited
24- class LitModelCheckpointMixin :
27+ class LitModelCheckpointMixin ( ABC ) :
2528 """Mixin class for LitModel checkpoint functionality."""
2629
27- def __init__ (self , model_name : str , * args : Any , ** kwargs : Any ) -> None :
30+ # mainly ofr mocking reasons
31+ _datetime_stamp : str = datetime .now ().strftime ("%Y%m%d-%H%M" )
32+ model_name : Optional [str ] = None
33+
34+ def __init__ (self , model_name : Optional [str ]) -> None :
2835 """Initialize with model name."""
36+ if not model_name :
37+ rank_zero_warn (
38+ "The model is not defined so we will continue with LightningModule names and timestamp of now"
39+ )
2940 self .model_name = model_name
3041
3142 try : # authenticate before anything else starts
@@ -38,8 +49,19 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
3849 def _upload_model (self , filepath : str ) -> None :
3950 # todo: uploading on background so training does nt stops
4051 # todo: use filename as version but need to validate that such version does not exists yet
52+ if not self .model_name :
53+ raise RuntimeError (
54+ "Model name is not specified neither updated by `setup` method via Trainer."
55+ " Please set the model name before uploading or ensure that `setup` method is called."
56+ )
4157 upload_model (name = self .model_name , model = filepath )
4258
59+ def _update_model_name (self , pl_model : "pl.LightningModule" ) -> None :
60+ if self .model_name :
61+ return
62+ # setting the model name as Lightning module with some time hash
63+ self .model_name = pl_model .__class__ .__name__ + f"_{ self ._datetime_stamp } "
64+
4365
4466# Create specific implementations
4567if _LIGHTNING_AVAILABLE :
@@ -53,15 +75,20 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
5375 kwargs: Additional keyword arguments to pass to the parent class.
5476 """
5577
56- def __init__ (self , model_name : str , * args : Any , ** kwargs : Any ) -> None :
78+ def __init__ (self , * args : Any , model_name : Optional [ str ] = None , ** kwargs : Any ) -> None :
5779 """Initialize the checkpoint with model name and other parameters."""
5880 _LightningModelCheckpoint .__init__ (self , * args , ** kwargs )
5981 LitModelCheckpointMixin .__init__ (self , model_name )
6082
61- def _save_checkpoint (self , trainer : "Trainer" , filepath : str ) -> None :
83+ def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
84+ """Setup the checkpoint callback."""
85+ super ().setup (trainer , pl_module , stage )
86+ self ._update_model_name (pl_module )
87+
88+ def _save_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
89+ """Extend the save checkpoint method to upload the model."""
6290 super ()._save_checkpoint (trainer , filepath )
63- if trainer .is_global_zero :
64- # Only upload from the main process
91+ if trainer .is_global_zero : # Only upload from the main process
6592 self ._upload_model (filepath )
6693
6794
@@ -76,13 +103,18 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
76103 kwargs: Additional keyword arguments to pass to the parent class.
77104 """
78105
79- def __init__ (self , model_name : str , * args : Any , ** kwargs : Any ) -> None :
106+ def __init__ (self , * args : Any , model_name : Optional [ str ] = None , ** kwargs : Any ) -> None :
80107 """Initialize the checkpoint with model name and other parameters."""
81108 _PytorchLightningModelCheckpoint .__init__ (self , * args , ** kwargs )
82109 LitModelCheckpointMixin .__init__ (self , model_name )
83110
84- def _save_checkpoint (self , trainer : "Trainer" , filepath : str ) -> None :
111+ def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
112+ """Setup the checkpoint callback."""
113+ super ().setup (trainer , pl_module , stage )
114+ self ._update_model_name (pl_module )
115+
116+ def _save_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
117+ """Extend the save checkpoint method to upload the model."""
85118 super ()._save_checkpoint (trainer , filepath )
86- if trainer .is_global_zero :
87- # Only upload from the main process
119+ if trainer .is_global_zero : # Only upload from the main process
88120 self ._upload_model (filepath )
0 commit comments