1- from typing import Any , TypeVar
1+ from typing import TYPE_CHECKING , Any
22
33from lightning_sdk .lightning_cloud .login import Auth
44from lightning_utilities .core .rank_zero import rank_zero_only
88
99if _LIGHTNING_AVAILABLE :
1010 from lightning .pytorch .callbacks import ModelCheckpoint as _LightningModelCheckpoint
11+
12+ if TYPE_CHECKING :
13+ from lightning .pytorch import Trainer
14+
15+
1116if _PYTORCHLIGHTNING_AVAILABLE :
1217 from pytorch_lightning .callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
1318
14- # Type variable for the ModelCheckpoint class
15- ModelCheckpointType = TypeVar ( "ModelCheckpointType" )
19+ if TYPE_CHECKING :
20+ from pytorch_lightning import Trainer
1621
1722
1823# Base class to be inherited
@@ -49,6 +54,7 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
4954 """
5055
5156 def __init__ (self , model_name : str , * args : Any , ** kwargs : Any ) -> None :
57+ """Initialize the checkpoint with model name and other parameters."""
5258 _LightningModelCheckpoint .__init__ (self , * args , ** kwargs )
5359 LitModelCheckpointMixin .__init__ (self , model_name )
5460
@@ -69,6 +75,7 @@ class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModel
6975 """
7076
7177 def __init__ (self , model_name : str , * args : Any , ** kwargs : Any ) -> None :
78+ """Initialize the checkpoint with model name and other parameters."""
7279 _PytorchLightningModelCheckpoint .__init__ (self , * args , ** kwargs )
7380 LitModelCheckpointMixin .__init__ (self , model_name )
7481
0 commit comments