1- from typing import Any , Type , TypeVar , cast
1+ from typing import Any , TypeVar
22
33from lightning_sdk .lightning_cloud .login import Auth
44from lightning_utilities .core .rank_zero import rank_zero_only
@@ -38,6 +38,7 @@ def _upload_model(self, filepath: str) -> None:
3838
3939# Create specific implementations
4040if _LIGHTNING_AVAILABLE :
41+
4142 class LightningModelCheckpoint (LitModelCheckpointMixin , _LightningModelCheckpoint ):
4243 """Lightning ModelCheckpoint with LitModel support.
4344
@@ -55,7 +56,9 @@ def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
5556 super ()._save_checkpoint (trainer , filepath )
5657 self ._upload_model (filepath )
5758
59+
5860if _PYTORCHLIGHTNING_AVAILABLE :
61+
5962 class PTLightningModelCheckpoint (LitModelCheckpointMixin , _PytorchLightningModelCheckpoint ):
6063 """PyTorch Lightning ModelCheckpoint with LitModel support.
6164
@@ -71,4 +74,4 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
7174
7275 def _save_checkpoint (self , trainer : "Trainer" , filepath : str ) -> None :
7376 super ()._save_checkpoint (trainer , filepath )
74- self ._upload_model (filepath )
77+ self ._upload_model (filepath )
0 commit comments