Skip to content

Commit 5fccb1d

Browse files
committed
TYPE_CHECKING
1 parent b160ce9 commit 5fccb1d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/litmodels/integrations/checkpoints.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, TypeVar
1+
from typing import TYPE_CHECKING, Any
22

33
from lightning_sdk.lightning_cloud.login import Auth
44
from lightning_utilities.core.rank_zero import rank_zero_only
@@ -8,11 +8,16 @@
88

99
if _LIGHTNING_AVAILABLE:
1010
from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint
11+
12+
if TYPE_CHECKING:
13+
from lightning.pytorch import Trainer
14+
15+
1116
if _PYTORCHLIGHTNING_AVAILABLE:
1217
from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint
1318

14-
# Type variable for the ModelCheckpoint class
15-
ModelCheckpointType = TypeVar("ModelCheckpointType")
19+
if TYPE_CHECKING:
20+
from pytorch_lightning import Trainer
1621

1722

1823
# Base class to be inherited
@@ -49,6 +54,7 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
4954
"""
5055

5156
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
57+
"""Initialize the checkpoint with model name and other parameters."""
5258
_LightningModelCheckpoint.__init__(self, *args, **kwargs)
5359
LitModelCheckpointMixin.__init__(self, model_name)
5460

@@ -69,6 +75,7 @@ class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModel
6975
"""
7076

7177
def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
78+
"""Initialize the checkpoint with model name and other parameters."""
7279
_PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
7380
LitModelCheckpointMixin.__init__(self, model_name)
7481

0 commit comments

Comments
 (0)